diff --git a/Project.toml b/Project.toml
index fd0882e97e..9623428c29 100644
--- a/Project.toml
+++ b/Project.toml
@@ -42,6 +42,7 @@ LLVM = "6.1, 7, 8, 9"
 LogExpFunctions = "0.3"
 ObjectFile = "0.4"
 Preferences = "1.4"
+SparseArrays = "1"
 SpecialFunctions = "1, 2"
 StaticArrays = "1"
 julia = "1.10"
diff --git a/src/Enzyme.jl b/src/Enzyme.jl
index b49c3738f6..091021cb8b 100644
--- a/src/Enzyme.jl
+++ b/src/Enzyme.jl
@@ -98,6 +98,7 @@ export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
 export markType, batch_size, onehot, chunkedonehot
 
 using LinearAlgebra
+import SparseArrays
 import EnzymeCore: ReverseMode, ReverseModeSplit, ForwardMode, Mode
 
 import EnzymeCore: EnzymeRules
diff --git a/src/internal_rules.jl b/src/internal_rules.jl
index f8c6e730bb..b6d081d57d 100644
--- a/src/internal_rules.jl
+++ b/src/internal_rules.jl
@@ -724,6 +724,131 @@ function EnzymeRules.reverse(
     return (nothing, nothing)
 end
 
+# Reverse-mode rule (forward sweep) for the five-argument `mul!(C, A, B, α, β)`, i.e.
+# `C = α*A*B + β*C`, with a constant sparse `A` and strided `B`/`C`. The primal
+# product is always executed in place. The tape handed to `reverse` is
+# `(cache_C, cache_A, cache_α)`; slots that are not needed hold `nothing`.
+function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig,
+                                func::Const{typeof(LinearAlgebra.mul!)},
+                                ::Type{RT},
+                                C::Annotation{<:StridedVecOrMat},
+                                A::Const{<:SparseArrays.SparseMatrixCSCUnion},
+                                B::Annotation{<:StridedVecOrMat},
+                                α::Annotation{<:Number},
+                                β::Annotation{<:Number}
+                                ) where {RT}
+
+    # dβ needs the value of C from before the product overwrites it.
+    cache_C = !(isa(β, Const)) ? copy(C.val) : nothing
+    # Always need to do forward pass otherwise primal may not be correct
+    func.val(C.val, A.val, B.val, α.val, β.val)
+
+    primal = EnzymeRules.needs_primal(config) ? C.val : nothing
+    shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
+
+    # A is only read in the reverse sweep when a gradient flows from C into B, so it
+    # is copied only if it may be overwritten in between and both B and C are
+    # non-Const.
+    # NOTE(review): confirm that index 5 of `overwritten(config)` refers to A.
+    cache_A = ( EnzymeRules.overwritten(config)[5]
+                && !(typeof(B) <: Const)
+                && !(typeof(C) <: Const)
+            ) ? copy(A.val) : nothing
+
+    # B itself is not cached: the reverse sweep only needs the product A*B (for dα).
+    cache_α = !isa(α, Const) ? A.val*B.val : nothing
+
+    cache = (cache_C, cache_A, cache_α)
+
+    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
+end
+
+# Reverse-mode rule (reverse sweep) for `mul!(C, A, B, α, β)` with a constant sparse
+# `A`. With `dC` the shadow of `C`, it accumulates `dB += α*A'*dC`, rescales
+# `dC *= β`, and returns `dα = dot(dC, A*B)` and `dβ = dot(dC, C_old)` for non-Const
+# scalars (one value per batch entry when the width N > 1). All other slots of the
+# returned tuple are `nothing`.
+function EnzymeRules.reverse(config::EnzymeRules.RevConfig,
+                                func::Const{typeof(LinearAlgebra.mul!)},
+                                ::Type{RT}, cache,
+                                C::Annotation{<:StridedVecOrMat},
+                                A::Const{<:SparseArrays.SparseMatrixCSCUnion},
+                                B::Annotation{<:StridedVecOrMat},
+                                α::Annotation{<:Number},
+                                β::Annotation{<:Number}
+                                ) where {RT}
+
+    cache_C, cache_A, cache_α = cache
+    # Fall back to the live values when the forward sweep did not need a copy.
+    Cval = !isnothing(cache_C) ? cache_C : C.val
+    Aval = !isnothing(cache_A) ? cache_A : A.val
+
+    N = EnzymeRules.width(config)
+
+    # A Const C carries no shadow, so nothing is propagated. dα and dβ must still be
+    # defined for the return value: zero for non-Const scalars, `nothing` otherwise.
+    if isa(C, Const)
+        zero_grad(x) = N == 1 ? zero(x.val) : ntuple(_ -> zero(x.val), Val(N))
+        dα = isa(α, Const) ? nothing : zero_grad(α)
+        dβ = isa(β, Const) ? nothing : zero_grad(β)
+        return (nothing, nothing, nothing, dα, dβ)
+    end
+
+    dCs = C.dval
+    # dCs is only a placeholder when B is Const; dBs is written to only for non-Const B.
+    dBs = isa(B, Const) ? dCs : B.dval
+
+    # The derivative of α*A*B + β*C with respect to α is A*B, cached as cache_α.
+    dα = if !isa(α, Const)
+        if N == 1
+            LinearAlgebra.dot(C.dval, cache_α)
+        else
+            ntuple(Val(N)) do i
+                Base.@_inline_meta
+                LinearAlgebra.dot(C.dval[i], cache_α)
+            end
+        end
+    else
+        nothing
+    end
+
+    # The derivative with respect to β is C as it was before the product.
+    dβ = if !isa(β, Const)
+        if N == 1
+            LinearAlgebra.dot(C.dval, Cval)
+        else
+            ntuple(Val(N)) do i
+                Base.@_inline_meta
+                LinearAlgebra.dot(C.dval[i], Cval)
+            end
+        end
+    else
+        nothing
+    end
+
+    for i in 1:N
+        # No dA update: A must be Const here, because a correct dA would have to be
+        # projected onto the sparsity pattern of A.
+        if !isa(B, Const)
+            #dB .+= α*A'*dC
+            if N == 1
+                func.val(dBs, Aval', dCs, α.val, true)
+            else
+                func.val(dBs[i], Aval', dCs[i], α.val, true)
+            end
+        end
+
+        # dC .*= β, done after dC was consumed by the dB update above.
+        if N == 1
+            dCs .*= β.val
+        else
+            dCs[i] .*= β.val
+        end
+    end
+
+    return (nothing, nothing, nothing, dα, dβ)
+end
+
 function EnzymeRules.forward(
     config::EnzymeRules.FwdConfig,
     ::Const{typeof(sort!)},
@@ -1269,4 +1394,4 @@ function EnzymeRules.reverse(
     smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}},
 ) where {rngty<:Union{TaskLocalRNG,Xoshiro},FT<:Union{Float32,Float64}}
     return (nothing, nothing, nothing)
-end
+end
\ No newline at end of file
diff --git a/test/internal_rules.jl b/test/internal_rules.jl
index 0d5bbdae01..246929272b 100644
--- a/test/internal_rules.jl
+++ b/test/internal_rules.jl
@@ -677,4 +677,46 @@ end
     # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f4(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((0.0,0.0)),)
 end
 
+@testset "SparseArrays spmatvec reverse rule" begin
+    C = zeros(18)
+    M = sprand(18, 9, 0.1)
+    v = randn(9)
+    α = 2.0
+    β = 1.0
+
+    for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
+        Tα in (Const, Active), Tβ in (Const, Active)
+
+        are_activities_compatible(Tret, Tret, Tv, Tα, Tβ) || continue
+        test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ))
+
+    end
+
+
+    for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
+        are_activities_compatible(Tret, Tret, Tv) || continue
+        test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const))
+    end
+end
+
+@testset "SparseArrays spmatmat reverse rule" begin
+    C = zeros(18, 11)
+    M = sprand(18, 9, 0.1)
+    v = randn(9, 11)
+    α = 2.0
+    β = 1.0
+
+    for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
+        Tα in (Const, Active), Tβ in (Const, Active)
+
+        are_activities_compatible(Tret, Tv, Tα, Tβ) || continue
+        test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ))
+    end
+
+    for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
+        are_activities_compatible(Tret, Tv) || continue
+        test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const))
+    end
+end
+
 end # InternalRules
diff --git a/test/runtests.jl b/test/runtests.jl
index d499febd77..bd1c7dd90d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -4101,4 +4101,4 @@ include("ext/logexpfunctions.jl")
 
 @testset "BFloat16s ext" begin
     include("ext/bfloat16s.jl")
-end
+end
\ No newline at end of file