From 77a8ccf47310ba95fc1a872911e08891462ce46e Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 27 Sep 2024 20:51:04 -0500 Subject: [PATCH 1/3] Optimize active only rev grad --- Project.toml | 2 + src/Enzyme.jl | 129 +++++++++++++++++++++++++++++++------------------- 2 files changed, 83 insertions(+), 48 deletions(-) diff --git a/Project.toml b/Project.toml index 0a03500ac9..2f72e8e4b7 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["William Moses ", "Valentin Churavy Date: Fri, 27 Sep 2024 20:51:37 -0500 Subject: [PATCH 2/3] Update Project.toml --- Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index 2f72e8e4b7..0a03500ac9 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ authors = ["William Moses ", "Valentin Churavy Date: Fri, 27 Sep 2024 20:59:55 -0500 Subject: [PATCH 3/3] add makezero s/marray --- ext/EnzymeStaticArraysExt.jl | 7 ++ src/Enzyme.jl | 129 +++++++++++++---------------------- 2 files changed, 55 insertions(+), 81 deletions(-) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index bcaa3ec6cb..b751c336a2 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -23,4 +23,11 @@ end end end +@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:SArray} + return Base.zero(x) +end +@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:MArray} + return Base.zero(x) +end + end diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 7901d6cb3c..091021cb8b 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1577,29 +1577,12 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) x::ty_0, args::Vararg{Any,N}, ) where {F,ty_0,ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten,N} - toemit = Expr[] - Actives = Bool[] - - if x <: Enzyme.Const - push!(toemit, quote - act_0 = false - end) - push!(Actives, false) - elseif Compiler.active_reg_inner(x, (), nothing) == Compiler.ActiveState - push!(toemit, quote - act_0 = false - end) - push!(Actives, true) - else - push!(toemit, quote - act_0 = - !(x isa Enzyme.Const) && - Compiler.active_reg_inner(Core.Typeof(x), (), nothing, Val(true)) == - Compiler.ActiveState #=justActive=# - end) - push!(Actives, false) - end - + toemit = Expr[quote + act_0 = + !(x isa Enzyme.Const) && + Compiler.active_reg_inner(Core.Typeof(x), (), nothing, Val(true)) == + Compiler.ActiveState #=justActive=# + end] rargs = Union{Symbol,Expr}[:x] acts = Symbol[Symbol("act_0")] @@ -1610,71 +1593,55 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) push!(rargs, argidx) sym = Symbol("act_$i") push!(acts, sym) - - if args[i] <: Enzyme.Const - push!(toemit, quote - $sym = false - end) - push!(Actives, false) - elseif Compiler.active_reg_inner(x, (), nothing) == Compiler.ActiveState - push!(toemit, quote - $sym = false - end) - push!(Actives, true) - else - push!(toemit, quote + push!( + toemit, + quote $sym = - !(x isa Enzyme.Const) && - Compiler.active_reg_inner(Core.Typeof($argidx), (), nothing, Val(true)) == - Compiler.ActiveState #=justActive=# - end) - push!(Actives, false) - end + !($argidx isa Enzyme.Const) && + Compiler.active_reg_inner( + Core.Typeof($argidx), + (), + nothing, + Val(true), + ) == Compiler.ActiveState #=justActive=# + end, + ) end idx = 0 shadows = Symbol[] enz_args = Expr[] resargs = Expr[] - for (arg, act, fact) in zip(rargs, acts, Actives) - if fact - push!(enz_args, quote - Active($arg) - end) - push!(resargs, quote - res[1][$(idx+1)] - end) - else - shad = Symbol("shad_$idx") - push!(shadows, shad) - push!(toemit, quote - $shad = if $arg isa Enzyme.Const - nothing - elseif $act - Ref(make_zero($arg)) - else - make_zero($arg) - end - end) - push!(enz_args, quote - if $arg isa Enzyme.Const - $arg - elseif $act - MixedDuplicated($arg, $shad) - else - Duplicated($arg, $shad) - end - end) - push!(resargs, quote - if $arg isa Enzyme.Const - nothing - elseif $act - $shad[] - else - $shad - end - end) - end + for (arg, act) in zip(rargs, acts) + shad = Symbol("shad_$idx") + push!(shadows, shad) + push!(toemit, quote + $shad = if $arg isa Enzyme.Const + nothing + elseif $act + Ref(make_zero($arg)) + else + make_zero($arg) + end + end) + push!(enz_args, quote + if $arg isa Enzyme.Const + $arg + elseif $act + MixedDuplicated($arg, $shad) + else + Duplicated($arg, $shad) + end + end) + push!(resargs, quote + if $arg isa Enzyme.Const + nothing + elseif $act + $shad[] + else + $shad + end + end) idx += 1 end push!(toemit, quote