Skip to content

Commit

Permalink
Merge branch 'main' into randn
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Sep 27, 2024
2 parents ab2e59b + 1261279 commit 988de6b
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 3 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ LLVM = "6.1, 7, 8, 9"
LogExpFunctions = "0.3"
ObjectFile = "0.4"
Preferences = "1.4"
SparseArrays = "1"
SpecialFunctions = "1, 2"
StaticArrays = "1"
julia = "1.10"
Expand Down
22 changes: 21 additions & 1 deletion lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,27 @@ end
DuplicatedNoNeed(x, ∂f_∂x)
Like [`Duplicated`](@ref), except also specifies that Enzyme may avoid computing
the original result and only compute the derivative values.
the original result and only compute the derivative values. This creates opportunities
for improved performance.
```julia
function square_byref(out, v)
out[] = v * v
nothing
end
out = Ref(0.0)
dout = Ref(1.0)
Enzyme.autodiff(Reverse, square_byref, DuplicatedNoNeed(out, dout), Active(1.0))
dout[]
# output
0.0
```
For example, marking the out variable as `DuplicatedNoNeed` instead of `Duplicated` allows
Enzyme to avoid computing `v * v` (while still computing its derivative).
This should only be used if `x` is a write-only variable. Otherwise, if the differentiated
function stores values in `x` and reads them back in subsequent computations, using
Expand Down
1 change: 1 addition & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
export markType, batch_size, onehot, chunkedonehot

using LinearAlgebra
import SparseArrays
import EnzymeCore: ReverseMode, ReverseModeSplit, ForwardMode, Mode

import EnzymeCore: EnzymeRules
Expand Down
11 changes: 10 additions & 1 deletion src/absint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -432,13 +432,22 @@ function abs_typeof(
lasti = i
end
end
if !seen && fieldcount(typ) > 0
offset = offset - fieldoffset(typ, lasti)
typ = fieldtype(typ, lasti)
@assert Base.isconcretetype(typ)
if !Base.allocatedinline(typ)
legal = false
end
seen = true
end
if !seen
legal = false
end
end

typ2 = typ
while should_recurse(typ2, value_type(arg), byref, dl)
while legal && should_recurse(typ2, value_type(arg), byref, dl)
idx, _ = first_non_ghost(typ2)
if idx != -1
typ2 = fieldtype(typ2, idx)
Expand Down
127 changes: 127 additions & 0 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,133 @@ function EnzymeRules.reverse(
return (nothing, nothing)
end


function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig,
func::Const{typeof(LinearAlgebra.mul!)},
::Type{RT},
C::Annotation{<:StridedVecOrMat},
A::Const{<:SparseArrays.SparseMatrixCSCUnion},
B::Annotation{<:StridedVecOrMat},
α::Annotation{<:Number},
β::Annotation{<:Number}
) where {RT}

cache_C = !(isa(β, Const)) ? copy(C.val) : nothing
# Always need to do forward pass otherwise primal may not be correct
func.val(C.val, A.val, B.val, α.val, β.val)

primal = if EnzymeRules.needs_primal(config)
C.val
else
nothing
end

shadow = if EnzymeRules.needs_shadow(config)
C.dval
else
nothing
end

# Check if A is overwritten and B is active (and thus required)
cache_A = ( EnzymeRules.overwritten(config)[5]
&& !(typeof(B) <: Const)
&& !(typeof(C) <: Const)
) ? copy(A.val) : nothing

# cache_B = ( EnzymeRules.overwritten(config)[6]) ? copy(B.val) : nothing

if !isa(α, Const)
cache_α = A.val*B.val
else
cache_α = nothing
end

cache = (cache_C, cache_A, cache_α)

return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfig,
func::Const{typeof(LinearAlgebra.mul!)},
::Type{RT}, cache,
C::Annotation{<:StridedVecOrMat},
A::Const{<:SparseArrays.SparseMatrixCSCUnion},
B::Annotation{<:StridedVecOrMat},
α::Annotation{<:Number},
β::Annotation{<:Number}
) where {RT}

cache_C, cache_A, cache_α = cache
Cval = !isnothing(cache_C) ? cache_C : C.val
Aval = !isnothing(cache_A) ? cache_A : A.val
# Bval = !isnothing(cache_B) ? cache_B : B.val

N = EnzymeRules.width(config)
if !isa(C, Const)
dCs = C.dval
dBs = isa(B, Const) ? dCs : B.dval

= if !isa(α, Const)
if N == 1
LinearAlgebra.dot(C.dval, cache_α)
else
ntuple(Val(N)) do i
Base.@_inline_meta
LinearAlgebra.dot(C.dval[i], cache_α)
end
end
else
nothing
end

= if !isa(β, Const)
if N == 1
LinearAlgebra.dot(C.dval, Cval)
else
ntuple(Val(N)) do i
Base.@_inline_meta
LinearAlgebra.dot(C.dval[i], Cval)
end
end
else
nothing
end

for i in 1:N
# This rule is incorrect since you need to project dA to have the same
# sparsity pattern as A.
# if !isa(A, Const)
# dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b]
# #dA .+= α*dC*B'
# mul!(dA, dC, Bval', α.val, true)
# end

if !isa(B, Const)
#dB .+= α*A'*dC
if N ==1
func.val(dBs, Aval', dCs, α.val, true)
else
func.val(dBs[i], Aval', dCs[i], α.val, true)
end
end

if N==1
dCs .*= β.val
else
dCs[i] .*= β.val
end
end
end

return (nothing, nothing, nothing, dα, dβ)
end







function EnzymeRules.forward(
config::EnzymeRules.FwdConfig,
::Const{typeof(sort!)},
Expand Down
42 changes: 42 additions & 0 deletions test/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -677,4 +677,46 @@ end
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f4(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((0.0,0.0)),)
end

@testset "SparseArrays spmatvec reverse rule" begin
C = zeros(18)
M = sprand(18, 9, 0.1)
v = randn(9)
α = 2.0
β = 1.0

for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
in (Const, Active), Tβ in (Const, Active)

are_activities_compatible(Tret, Tret, Tv, Tα, Tβ) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ))

end


for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
are_activities_compatible(Tret, Tret, Tv) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const))
end
end

@testset "SparseArrays spmatmat reverse rule" begin
C = zeros(18, 11)
M = sprand(18, 9, 0.1)
v = randn(9, 11)
α = 2.0
β = 1.0

for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
in (Const, Active), Tβ in (Const, Active)

are_activities_compatible(Tret, Tv, Tα, Tβ) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ))
end

for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
are_activities_compatible(Tret, Tv) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const))
end
end

end # InternalRules
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4101,4 +4101,4 @@ include("ext/logexpfunctions.jl")

@testset "BFloat16s ext" begin
include("ext/bfloat16s.jl")
end
end

0 comments on commit 988de6b

Please sign in to comment.