Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reverse rule for Sparse dense matmul/vec #1792

Merged
merged 36 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
ab1807c
Add sparse array internal rule
ptiede Sep 4, 2024
8a35c83
Add sparsearray extension for mul!
ptiede Sep 5, 2024
ee96ec1
Add more testing
ptiede Sep 5, 2024
b8c5370
Add BatchDuplicated (still broken)
ptiede Sep 5, 2024
dc1c979
Remove BatchMode since it isn't applicable?
ptiede Sep 5, 2024
e548322
Add sparse array testing
ptiede Sep 5, 2024
c71470f
Don't support batchmode for now
ptiede Sep 5, 2024
c99ea6c
Revert project to old style
ptiede Sep 5, 2024
582b675
Add sparse array compat bound
ptiede Sep 5, 2024
e432afb
reenable batch mode for bug hunting
ptiede Sep 6, 2024
05924ed
Turn on BatchDuplicated stuff again
ptiede Sep 6, 2024
be8076b
Merge branch 'main' into ptiede-spdensemul
ptiede Sep 6, 2024
85f2df7
Merge branch 'main' into ptiede-spdensemul
ptiede Sep 6, 2024
c63d21b
Remove Q comment
ptiede Sep 6, 2024
5c17efd
Encorporate BatchDuplicated into testing properly
ptiede Sep 6, 2024
26d4429
Merge branch 'main' into ptiede-spdensemul
ptiede Sep 12, 2024
8292d1e
Merge branch 'main' into ptiede-spdensemul
wsmoses Sep 12, 2024
c654b5d
Consider constant fp in runtime activity (#1797)
wsmoses Sep 6, 2024
e9d34bc
Suggest workaround in error for overwritten active by ref (#1791)
danielwe Sep 6, 2024
1a36727
Fix custom active reverse mode check (#1798)
wsmoses Sep 6, 2024
eaaab3c
Remove Q comment
ptiede Sep 6, 2024
e6f67ed
Encorporate BatchDuplicated into testing properly
ptiede Sep 6, 2024
be19ffb
Look for more writebarrier opportunities (#1800)
wsmoses Sep 6, 2024
47b4742
Restrict version to 1.10+ (#1809)
wsmoses Sep 12, 2024
5430264
Update Project.toml
wsmoses Sep 12, 2024
43a16b5
Fix MixedDuplicated ABI error on primalerror (#1815)
wsmoses Sep 12, 2024
cc06e31
Update test
ptiede Sep 27, 2024
291f8f7
Fix ext
ptiede Sep 27, 2024
06c3d8d
Merge branch 'main' into ptiede-spdensemul
ptiede Sep 27, 2024
f752ec8
Move new SparseArrays Cholmod into extension
ptiede Sep 27, 2024
be3c61b
Make LinearAlgebra.mul! explicit
ptiede Sep 27, 2024
df7fb4a
Make sparse arrays not a extension
ptiede Sep 27, 2024
c4e6924
Fix rules for 0.13
ptiede Sep 27, 2024
25b3dcd
Remove sparse arrays ext file
ptiede Sep 27, 2024
85c88f6
Update compiler
ptiede Sep 27, 2024
5bd31f8
Merge branch 'main' into ptiede-spdensemul
ptiede Sep 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ LLVM = "6.1, 7, 8, 9"
LogExpFunctions = "0.3"
ObjectFile = "0.4"
Preferences = "1.4"
SparseArrays = "1"
SpecialFunctions = "1, 2"
StaticArrays = "1"
julia = "1.10"
Expand Down
1 change: 1 addition & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
export markType, batch_size, onehot, chunkedonehot

using LinearAlgebra
import SparseArrays
import EnzymeCore: ReverseMode, ReverseModeSplit, ForwardMode, Mode

import EnzymeCore: EnzymeRules
Expand Down
129 changes: 128 additions & 1 deletion src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,133 @@ function EnzymeRules.reverse(
return (nothing, nothing)
end


# Augmented forward pass of the reverse-mode rule for the five-argument
# `LinearAlgebra.mul!(C, A, B, α, β)` (C ← α*A*B + β*C) with a constant sparse `A`
# and strided dense `B`/`C`.
#
# Runs the primal `mul!` in place and returns an `AugmentedReturn` holding
#   * the primal `C.val` (only if `needs_primal(config)`),
#   * the shadow `C.dval` (only if `needs_shadow(config)`),
#   * the tape `(cache_C, cache_A, cache_α)` consumed by the matching `reverse` method.
# Each tape entry is `nothing` when the reverse pass does not need it.
function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfig,
    func::Const{typeof(LinearAlgebra.mul!)},
    ::Type{RT},
    C::Annotation{<:StridedVecOrMat},
    A::Const{<:SparseArrays.SparseMatrixCSCUnion},
    B::Annotation{<:StridedVecOrMat},
    α::Annotation{<:Number},
    β::Annotation{<:Number},
) where {RT}
    # The β cotangent is ⟨dC, C_old⟩, so snapshot C *before* mul! overwrites it.
    cache_C = !(isa(β, Const)) ? copy(C.val) : nothing

    # Always need to do forward pass otherwise primal may not be correct
    func.val(C.val, A.val, B.val, α.val, β.val)

    primal = EnzymeRules.needs_primal(config) ? C.val : nothing
    shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing

    # A is only read in the reverse pass when both B and C carry shadows
    # (dB .+= α*A'*dC), so it only has to be saved if it may be overwritten in between.
    # `overwritten(config)` counts the function itself as entry 1, hence for
    # mul!(C, A, B, α, β) the slots are (func, C, A, B, α, β) and A is entry 3.
    cache_A = if EnzymeRules.overwritten(config)[3] && !isa(B, Const) && !isa(C, Const)
        copy(A.val)
    else
        nothing
    end

    # B itself is never read in the reverse pass: the α cotangent ⟨dC, A*B⟩ only
    # needs the product, which is stored here instead of a copy of B.
    cache_α = !isa(α, Const) ? A.val * B.val : nothing

    cache = (cache_C, cache_A, cache_α)

    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

# Reverse pass of the rule for `LinearAlgebra.mul!(C, A, B, α, β)` with a constant
# sparse `A` and strided dense `B`/`C`.
#
# `cache` is the `(cache_C, cache_A, cache_α)` tuple built by the matching
# `augmented_primal`. For every batch lane the shadows are updated in place:
#   dB .+= α * A' * dC      (only when B is not Const)
#   dC .*= β
# The returned tuple has one entry per primal argument `(C, A, B, α, β)`; array
# arguments return `nothing`, and α/β return their cotangent (a tuple of
# cotangents when the batch width is larger than 1) or `nothing` when `Const`.
#
# `A` is restricted to `Const`: accumulating dA .+= α*dC*B' would be incorrect
# because dA has to be projected onto the sparsity pattern of A, which is not
# implemented here.
function EnzymeRules.reverse(
    config::EnzymeRules.RevConfig,
    func::Const{typeof(LinearAlgebra.mul!)},
    ::Type{RT},
    cache,
    C::Annotation{<:StridedVecOrMat},
    A::Const{<:SparseArrays.SparseMatrixCSCUnion},
    B::Annotation{<:StridedVecOrMat},
    α::Annotation{<:Number},
    β::Annotation{<:Number},
) where {RT}
    cache_C, cache_A, cache_α = cache
    # Fall back to the live values when nothing had to be saved on the tape.
    Cval = !isnothing(cache_C) ? cache_C : C.val
    Aval = !isnothing(cache_A) ? cache_A : A.val

    N = EnzymeRules.width(config)

    # dα = ⟨dC, A*B⟩. When C is Const there is no dC, so an active α must still
    # receive a (zero) cotangent instead of leaving dα undefined.
    dα = if isa(α, Const)
        nothing
    elseif isa(C, Const)
        N == 1 ? zero(α.val) : ntuple(_ -> zero(α.val), Val(N))
    elseif N == 1
        LinearAlgebra.dot(C.dval, cache_α)
    else
        ntuple(Val(N)) do i
            Base.@_inline_meta
            LinearAlgebra.dot(C.dval[i], cache_α)
        end
    end

    # dβ = ⟨dC, C_old⟩, with the same zero fallback when C is Const.
    dβ = if isa(β, Const)
        nothing
    elseif isa(C, Const)
        N == 1 ? zero(β.val) : ntuple(_ -> zero(β.val), Val(N))
    elseif N == 1
        LinearAlgebra.dot(C.dval, Cval)
    else
        ntuple(Val(N)) do i
            Base.@_inline_meta
            LinearAlgebra.dot(C.dval[i], Cval)
        end
    end

    if !isa(C, Const)
        dCs = C.dval

        for i in 1:N
            # With width 1 the shadow is the array itself, otherwise a tuple of arrays.
            dC = N == 1 ? dCs : dCs[i]

            if !isa(B, Const)
                dB = N == 1 ? B.dval : B.dval[i]
                # dB .+= α*A'*dC  (must run before dC is rescaled below)
                func.val(dB, Aval', dC, α.val, true)
            end

            # Propagate through the β*C_old term of the primal.
            dC .*= β.val
        end
    end

    return (nothing, nothing, nothing, dα, dβ)
end







function EnzymeRules.forward(
config::EnzymeRules.FwdConfig,
::Const{typeof(sort!)},
Expand Down Expand Up @@ -1269,4 +1396,4 @@ function EnzymeRules.reverse(
smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}},
) where {rngty<:Union{TaskLocalRNG,Xoshiro},FT<:Union{Float32,Float64}}
return (nothing, nothing, nothing)
end
end
42 changes: 42 additions & 0 deletions test/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -677,4 +677,46 @@ end
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f4(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((0.0,0.0)),)
end

@testset "SparseArrays spmatvec reverse rule" begin
    # C must be non-zero: the β cotangent is ⟨dC, C_old⟩, which is identically zero
    # for a zero C and would let an incorrect β rule (or a missing C cache) pass.
    C = randn(18)
    M = sprand(18, 9, 0.1)
    v = randn(9)
    α = 2.0
    β = 1.0

    # Floating-point α/β, each either constant or active.
    for Tret in (Duplicated, BatchDuplicated),
        Tv in (Const, Duplicated, BatchDuplicated),
        Tα in (Const, Active),
        Tβ in (Const, Active)

        # First argument is the return activity, the rest are the argument
        # activities (C shares the return activity because mul! returns C).
        are_activities_compatible(Tret, Tret, Tv, Tα, Tβ) || continue
        test_reverse(
            LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)
        )
    end

    # Boolean α/β (always constant), covering the true/false special cases of mul!.
    for Tret in (Duplicated, BatchDuplicated),
        Tv in (Const, Duplicated, BatchDuplicated),
        bα in (true, false),
        bβ in (true, false)

        are_activities_compatible(Tret, Tret, Tv) || continue
        test_reverse(
            LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)
        )
    end
end

@testset "SparseArrays spmatmat reverse rule" begin
    # C must be non-zero: the β cotangent is ⟨dC, C_old⟩, which is identically zero
    # for a zero C and would let an incorrect β rule (or a missing C cache) pass.
    C = randn(18, 11)
    M = sprand(18, 9, 0.1)
    v = randn(9, 11)
    α = 2.0
    β = 1.0

    # Floating-point α/β, each either constant or active.
    for Tret in (Duplicated, BatchDuplicated),
        Tv in (Const, Duplicated, BatchDuplicated),
        Tα in (Const, Active),
        Tβ in (Const, Active)

        # First argument is the return activity, the rest are the argument
        # activities (C shares the return activity because mul! returns C).
        are_activities_compatible(Tret, Tret, Tv, Tα, Tβ) || continue
        test_reverse(
            LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)
        )
    end

    # Boolean α/β (always constant), covering the true/false special cases of mul!.
    for Tret in (Duplicated, BatchDuplicated),
        Tv in (Const, Duplicated, BatchDuplicated),
        bα in (true, false),
        bβ in (true, false)

        are_activities_compatible(Tret, Tret, Tv) || continue
        test_reverse(
            LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)
        )
    end
end

end # InternalRules
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4101,4 +4101,4 @@ include("ext/logexpfunctions.jl")

@testset "BFloat16s ext" begin
include("ext/bfloat16s.jl")
end
end
Loading