Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP invoke fix #1169

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f.val, tt)
if !allocatedinline(rt) || rt isa Union
forward, adjoint = Enzyme.Compiler.thunk(Val(world), FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI)
forward, adjoint = Enzyme.Compiler.thunk(nothing, Val(world), FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI)
res = forward(f, args′...)
tape = res[1]
if ReturnPrimal
Expand All @@ -206,7 +206,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
elseif A <: Duplicated || A<: DuplicatedNoNeed || A <: BatchDuplicated || A<: BatchDuplicatedNoNeed || A <: BatchDuplicatedFunc
throw(ErrorException("Duplicated Returns not yet handled"))
end
thunk = Enzyme.Compiler.thunk(Val(world), FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
thunk = Enzyme.Compiler.thunk(nothing, Val(world), FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
if A <: Active
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f.val, tt)
Expand Down Expand Up @@ -319,7 +319,7 @@ f(x) = x*x
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
world = codegen_world_age(Core.Typeof(f.val), tt)

thunk = Enzyme.Compiler.thunk(Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width),
thunk = Enzyme.Compiler.thunk(nothing, Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width),
ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI)
thunk(f, args′...)
end
Expand Down Expand Up @@ -521,7 +521,7 @@ result, ∂v, ∂A
if !(A <: Const)
@assert ReturnShadow
end
Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
Enzyme.Compiler.thunk(nothing, Val(world), FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
end

"""
Expand Down Expand Up @@ -584,7 +584,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated

world = codegen_world_age(eltype(FA), tt)

Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI)
Enzyme.Compiler.thunk(nothing, Val(world), FA, A, Tuple{args...}, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI)
end

@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI}
Expand All @@ -610,7 +610,7 @@ end

primal_tt = Tuple{map(eltype, args)...}
world = codegen_world_age(eltype(FA), primal_tt)
nondef = Enzyme.Compiler.thunk(Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
nondef = Enzyme.Compiler.thunk(nothing, Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
TapeType = EnzymeRules.tape_type(nondef[1])
return TapeType
end
Expand Down Expand Up @@ -684,7 +684,7 @@ result, ∂v, ∂A
world = codegen_world_age(eltype(FA), primal_tt)

# TODO this assumes that the thunk here has the correct parent/etc things for getting the right cuda instructions -> same caching behavior
nondef = Enzyme.Compiler.thunk(Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
nondef = Enzyme.Compiler.thunk(nothing, Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
TapeType = EnzymeRules.tape_type(nondef[1])
A2 = Compiler.return_type(typeof(nondef[1]))

Expand Down Expand Up @@ -995,15 +995,15 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2))
ModifiedBetween = Val((false, false))
FA = Const{Core.Typeof(f)}
World = Val(nothing)
primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI)
primal, adjoint = Enzyme.Compiler.thunk(nothing, Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI)

if num * chunk == n_out_val
last_size = chunk
primal2, adjoint2 = primal, adjoint
else
last_size = n_out_val - (num-1)*chunk
tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}}
primal2, adjoint2 = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI)
primal2, adjoint2 = Enzyme.Compiler.thunk(nothing, Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI)
end

tmp = ntuple(num) do i
Expand Down Expand Up @@ -1034,7 +1034,7 @@ end
rt = Core.Compiler.return_type(f, tt)
ModifiedBetween = Val((false, false))
FA = Const{Core.Typeof(f)}
primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI)
primal, adjoint = Enzyme.Compiler.thunk(nothing, Val(world), FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI)
rows = ntuple(n_outs) do i
Base.@_inline_meta
dx = zero(x)
Expand Down
36 changes: 19 additions & 17 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5298,10 +5298,8 @@ end
# Strip the batch width from a batched activity annotation, yielding the
# corresponding single (non-batch) annotation type. Used when building
# `EnzymeCompilerParams` so the return activity is carried in canonical,
# width-free form.
@inline function remove_innerty(::Type{<:BatchDuplicated})
    return Duplicated
end
@inline function remove_innerty(::Type{<:BatchDuplicatedNoNeed})
    return DuplicatedNoNeed
end

@inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI}
@inline function thunkbase(mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI}
JuliaContext() do ctx
mi = fspec(eltype(FA), TT, World)

target = Compiler.EnzymeTarget()
params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI)
tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World)
Expand Down Expand Up @@ -5352,30 +5350,34 @@ end
TapeType = compile_result.TapeType
AugT = AugmentedForwardThunk{typeof(compile_result.primal), FA, rt2, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal), TapeType}
AdjT = AdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, Val{width}, TapeType}
return quote
Base.@_inline_meta
augmented = $AugT($(compile_result.primal))
adjoint = $AdjT($(compile_result.adjoint))
(augmented, adjoint)
end
augmented = AugT((compile_result.primal))
adjoint = AdjT((compile_result.adjoint))
return (augmented, adjoint)
elseif Mode == API.DEM_ReverseModeCombined
CAdjT = CombinedAdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal)}
return quote
Base.@_inline_meta
$CAdjT($(compile_result.adjoint))
end
return CAdjT(compile_result.adjoint)
elseif Mode == API.DEM_ForwardMode
FMT = ForwardModeThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal)}
return quote
Base.@_inline_meta
$FMT($(compile_result.adjoint))
end
return FMT(compile_result.adjoint)
else
@assert false
end
end
end

# Runtime entry point for callers that already hold a resolved
# `Core.MethodInstance` for the primal: forwards every configuration value
# straight into `thunkbase`, bypassing the generated-function path below
# (which performs its own method-instance lookup).
@inline function thunk(mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT}, ::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI}
    return thunkbase(mi, Val(World), FA, A, TT,
                     Val(Mode), Val(width), Val(ModifiedBetween),
                     Val(ReturnPrimal), Val(ShadowInit), ABI)
end

# Generated-function entry point used when no `MethodInstance` is supplied
# (first argument is `nothing`): the method-instance lookup (`fspec`) and the
# full thunk compilation (`thunkbase`) run at generation time for the caller's
# requested world `World`, and the resulting thunk object is spliced into the
# caller as a constant.
#
# NOTE(review): invoking `thunkbase` — which drives LLVM compilation — inside
# a generated-function body relies on generator side effects; confirm this is
# safe w.r.t. precompilation and the purity constraints on generators.
@inline @generated function thunk(::Nothing, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI}
    # Resolve the primal's method instance in world `World`.
    mi = fspec(eltype(FA), TT, World)
    # Compile the thunk(s) now, at generation time.
    res = thunkbase(mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI)
    # Emit code that simply returns the precompiled thunk object.
    return quote
        Base.@_inline_meta
        return $(res)
    end
end

import GPUCompiler: deferred_codegen_jobs

@generated function deferred_codegen(::Val{World}, ::Type{FA}, ::Val{TT}, ::Val{A},::Val{Mode},
Expand Down
4 changes: 2 additions & 2 deletions src/rules/activityrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function julia_activity_rule(f::LLVM.Function)

# Unsupported calling conv
# also wouldn't have any type info for this [would for earlier args though]
if mi.specTypes.parameters[end] === Vararg{Any}
if Base.isvarargtype(mi.specTypes.parameters[end])
return
end

Expand Down Expand Up @@ -71,4 +71,4 @@ function julia_activity_rule(f::LLVM.Function)
push!(return_attributes(f), StringAttribute("enzyme_inactive"))
end
end
end
end
Loading
Loading