Add reverse rule for Sparse dense matmul/vec (#1792)
* Add sparse array internal rule

* Add sparsearray extension for mul!

* Add more testing

* Add BatchDuplicated (still broken)

* Remove BatchMode since it isn't applicable?

* Add sparse array testing

* Don't support batchmode for now

* Revert project to old style

* Add sparse array compat bound

* reenable batch mode for bug hunting

* Turn on BatchDuplicated stuff again

* Remove Q comment

* Incorporate BatchDuplicated into testing properly

* Consider constant fp in runtime activity (#1797)

* Consider constant fp in runtime activity

* fix

* Suggest workaround in error for overwritten active by ref (#1791)

* Fix custom active reverse mode check (#1798)

* Remove Q comment

* Incorporate BatchDuplicated into testing properly

* Look for more writebarrier opportunities (#1800)

* Look for more writebarrier opportunities

* Update compiler.jl

* Restrict version to 1.10+ (#1809)

* Restrict version to 1.10+

* fix

* fixup

* Update CI.yml

* Update Project.toml

* Update Project.toml

* Update Project.toml

* Fix MixedDuplicated ABI error on primalerror (#1815)

* Update test

* Move new SparseArrays Cholmod into extension

* Make LinearAlgebra.mul! explicit

* Make sparse arrays not a extension

* Fix rules for 0.13

* Remove sparse arrays ext file

* Update compiler

---------

Co-authored-by: William Moses <[email protected]>
Co-authored-by: Daniel Wennberg <[email protected]>
3 people authored Sep 27, 2024
1 parent d092d4a commit 1261279
Showing 5 changed files with 173 additions and 2 deletions.
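
For orientation (this note and the snippet are not part of the committed diff): the new rule lets reverse-mode Enzyme differentiate LinearAlgebra.mul! when the sparse matrix is held Const, mirroring the activity combinations exercised in the tests added below. A minimal sketch, assuming the wrapper f! and all variable names are illustrative:

using Enzyme, LinearAlgebra, SparseArrays

# In-place y .= α*A*x + β*y, wrapped so that nothing is returned.
f!(y, A, x, α, β) = (mul!(y, A, x, α, β); nothing)

A = sprand(5, 3, 0.5)          # sparse matrix, held Const (no dA is produced)
x = randn(3);  dx = zeros(3)
y = zeros(5);  dy = ones(5)    # seed for the adjoint of y

# dx accumulates α * A' * dy in place; the Active scalars come back in the
# returned tuple as ((nothing, nothing, nothing, dα, dβ),).
Enzyme.autodiff(Reverse, f!, Const,
    Duplicated(y, dy), Const(A), Duplicated(x, dx),
    Active(2.0), Active(1.0))
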
1 change: 1 addition & 0 deletions Project.toml
@@ -42,6 +42,7 @@ LLVM = "6.1, 7, 8, 9"
LogExpFunctions = "0.3"
ObjectFile = "0.4"
Preferences = "1.4"
SparseArrays = "1"
SpecialFunctions = "1, 2"
StaticArrays = "1"
julia = "1.10"
1 change: 1 addition & 0 deletions src/Enzyme.jl
@@ -98,6 +98,7 @@ export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
export markType, batch_size, onehot, chunkedonehot

using LinearAlgebra
import SparseArrays
import EnzymeCore: ReverseMode, ReverseModeSplit, ForwardMode, Mode

import EnzymeCore: EnzymeRules
129 changes: 128 additions & 1 deletion src/internal_rules.jl
@@ -724,6 +724,133 @@ function EnzymeRules.reverse(
return (nothing, nothing)
end


function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig,
func::Const{typeof(LinearAlgebra.mul!)},
::Type{RT},
C::Annotation{<:StridedVecOrMat},
A::Const{<:SparseArrays.SparseMatrixCSCUnion},
B::Annotation{<:StridedVecOrMat},
α::Annotation{<:Number},
β::Annotation{<:Number}
) where {RT}

cache_C = !(isa(β, Const)) ? copy(C.val) : nothing
# Always need to do forward pass otherwise primal may not be correct
func.val(C.val, A.val, B.val, α.val, β.val)

primal = if EnzymeRules.needs_primal(config)
C.val
else
nothing
end

shadow = if EnzymeRules.needs_shadow(config)
C.dval
else
nothing
end

# Check if A is overwritten and B is active (and thus required)
cache_A = ( EnzymeRules.overwritten(config)[5]
&& !(typeof(B) <: Const)
&& !(typeof(C) <: Const)
) ? copy(A.val) : nothing

# cache_B = ( EnzymeRules.overwritten(config)[6]) ? copy(B.val) : nothing

if !isa(α, Const)
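# The reverse pass needs dα = dot(dC, A*B), so stash A*B while A and B still
# hold their primal values.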
cache_α = A.val*B.val
else
cache_α = nothing
end

cache = (cache_C, cache_A, cache_α)

return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfig,
func::Const{typeof(LinearAlgebra.mul!)},
::Type{RT}, cache,
C::Annotation{<:StridedVecOrMat},
A::Const{<:SparseArrays.SparseMatrixCSCUnion},
B::Annotation{<:StridedVecOrMat},
α::Annotation{<:Number},
β::Annotation{<:Number}
) where {RT}

cache_C, cache_A, cache_α = cache
Cval = !isnothing(cache_C) ? cache_C : C.val
Aval = !isnothing(cache_A) ? cache_A : A.val
# Bval = !isnothing(cache_B) ? cache_B : B.val
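# With C ← β*C₀ + α*A*B, where C₀ is the value C held before mul! (cached above as
# cache_C whenever β is not Const), the adjoints computed below are
#   dB += α * A' * dC,   dα = dot(dC, A*B),   dβ = dot(dC, C₀),   dC ← β * dC.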

N = EnzymeRules.width(config)
if !isa(C, Const)
dCs = C.dval
dBs = isa(B, Const) ? dCs : B.dval

dα = if !isa(α, Const)
if N == 1
LinearAlgebra.dot(C.dval, cache_α)
else
ntuple(Val(N)) do i
Base.@_inline_meta
LinearAlgebra.dot(C.dval[i], cache_α)
end
end
else
nothing
end

dβ = if !isa(β, Const)
if N == 1
LinearAlgebra.dot(C.dval, Cval)
else
ntuple(Val(N)) do i
Base.@_inline_meta
LinearAlgebra.dot(C.dval[i], Cval)
end
end
else
nothing
end

for i in 1:N
# This rule is incorrect since you need to project dA to have the same
# sparsity pattern as A.
# if !isa(A, Const)
# dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b]
# #dA .+= α*dC*B'
# mul!(dA, dC, Bval', α.val, true)
# end

if !isa(B, Const)
#dB .+= α*A'*dC
if N == 1
func.val(dBs, Aval', dCs, α.val, true)
else
func.val(dBs[i], Aval', dCs[i], α.val, true)
end
end

if N == 1
dCs .*= β.val
else
dCs[i] .*= β.val
end
end
end

return (nothing, nothing, nothing, dα, dβ)
end
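
# --- Illustration only; not part of this commit ---
# The dA update commented out above is disabled because α*dC*B' is dense, while dA
# must keep the sparsity pattern of A. A hedged sketch of a pattern-preserving
# accumulation (helper name and exact signature are hypothetical):
function accumulate_dA!(dA::SparseArrays.SparseMatrixCSC, dC::AbstractMatrix,
                        B::AbstractMatrix, α)
    rows = SparseArrays.rowvals(dA)
    vals = SparseArrays.nonzeros(dA)
    for j in 1:size(dA, 2), k in SparseArrays.nzrange(dA, j)
        i = rows[k]
        # (α * dC * B')[i, j], restricted to the stored entry (i, j) of dA
        vals[k] += α * LinearAlgebra.dot(view(dC, i, :), view(B, j, :))
    end
    return dA
end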







function EnzymeRules.forward(
config::EnzymeRules.FwdConfig,
::Const{typeof(sort!)},
@@ -1269,4 +1396,4 @@ function EnzymeRules.reverse(
smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}},
) where {rngty<:Union{TaskLocalRNG,Xoshiro},FT<:Union{Float32,Float64}}
return (nothing, nothing, nothing)
end
end
42 changes: 42 additions & 0 deletions test/internal_rules.jl
@@ -677,4 +677,46 @@ end
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f4(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((0.0,0.0)),)
end

@testset "SparseArrays spmatvec reverse rule" begin
C = zeros(18)
M = sprand(18, 9, 0.1)
v = randn(9)
α = 2.0
β = 1.0

for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
Tα in (Const, Active), Tβ in (Const, Active)

are_activities_compatible(Tret, Tret, Tv, Tα, Tβ) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ))

end


for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
are_activities_compatible(Tret, Tret, Tv) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const))
end
end

@testset "SparseArrays spmatmat reverse rule" begin
C = zeros(18, 11)
M = sprand(18, 9, 0.1)
v = randn(9, 11)
α = 2.0
β = 1.0

for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
Tα in (Const, Active), Tβ in (Const, Active)

are_activities_compatible(Tret, Tv, Tα, Tβ) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ))
end

for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
are_activities_compatible(Tret, Tv) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const))
end
end

end # InternalRules
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -4101,4 +4101,4 @@ include("ext/logexpfunctions.jl")

@testset "BFloat16s ext" begin
include("ext/bfloat16s.jl")
end
end
