Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt to pending Enzyme breaking change #543

Merged
merged 7 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/Downgrade.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
version: ['1']
group:
- Core
- Enzyme
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
- "LinearSolveHYPRE"
- "LinearSolvePardiso"
- "LinearSolveBandedMatrices"
- "Enzyme"
uses: "SciML/.github/.github/workflows/tests.yml@v1"
with:
group: "${{ matrix.group }}"
Expand Down
11 changes: 5 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
Expand All @@ -52,8 +51,8 @@ LinearSolveBandedMatricesExt = "BandedMatrices"
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveCUDSSExt = "CUDSS"
LinearSolveEnzymeExt = ["Enzyme", "EnzymeCore"]
LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"]
LinearSolveEnzymeExt = "EnzymeCore"
LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
Expand All @@ -74,8 +73,8 @@ ChainRulesCore = "1.22"
ConcreteStructs = "0.2.3"
DocStringExtensions = "0.9.3"
EnumX = "1.0.4"
Enzyme = "0.11.15, 0.12, 0.13"
EnzymeCore = "0.6.5, 0.7, 0.8"
Enzyme = "0.13"
EnzymeCore = "0.8"
FastAlmostBandedMatrices = "0.1"
FastLapackInterface = "2"
FiniteDiff = "2.22"
Expand All @@ -84,7 +83,7 @@ GPUArraysCore = "0.1.6"
HYPRE = "1.4.0"
InteractiveUtils = "1.10"
IterativeSolvers = "0.9.3"
JET = "0.8.28"
JET = "0.8.28, 0.9"
KLU = "0.6"
KernelAbstractions = "0.9.16"
Krylov = "0.9"
Expand Down
58 changes: 35 additions & 23 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,50 @@ module LinearSolveEnzymeExt

using LinearSolve
using LinearSolve.LinearAlgebra
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)

using Enzyme

using EnzymeCore
using EnzymeCore: EnzymeRules

function EnzymeCore.EnzymeRules.forward(
function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP},
alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
@assert !(prob isa Const)
res = func.val(prob.val, alg.val; kwargs...)
if RT <: Const
return res
if EnzymeRules.needs_primal(config)
return res
else
return nothing
end
end
dres = func.val(prob.dval, alg.val; kwargs...)
dres.b .= res.b == dres.b ? zero(dres.b) : dres.b
dres.A .= res.A == dres.A ? zero(dres.A) : dres.A
if RT <: DuplicatedNoNeed
return dres
elseif RT <: Duplicated

if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated(res, dres)
elseif EnzymeRules.needs_shadow(config)
return dres
elseif EnzymeRules.needs_primal(config)
return res
else
return nothing
end
error("Unsupported return type $RT")
end

function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},
function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)},
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
@assert !(linsolve isa Const)

res = func.val(linsolve.val; kwargs...)

if RT <: Const
return res
if EnzymeRules.needs_primal(config)
return res
else
return nothing
end
end
if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
Expand All @@ -50,16 +60,18 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},

linsolve.val.b = b

if RT <: DuplicatedNoNeed
return dres
elseif RT <: Duplicated
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated(res, dres)
elseif EnzymeRules.needs_shadow(config)
return dres
elseif EnzymeRules.needs_primal(config)
return res
else
return nothing
end

return Duplicated(res, dres)
end

function EnzymeCore.EnzymeRules.augmented_primal(
function EnzymeRules.augmented_primal(
config, func::Const{typeof(LinearSolve.init)},
::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const;
kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
Expand Down Expand Up @@ -94,10 +106,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
(dval.b for dval in prob.dval)
end

return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
return EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
end

function EnzymeCore.EnzymeRules.reverse(
function EnzymeRules.reverse(
config, func::Const{typeof(LinearSolve.init)}, ::Type{RT},
cache, prob::EnzymeCore.Annotation{LP}, alg::Const;
kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
Expand Down Expand Up @@ -131,7 +143,7 @@ end
# y=inv(A) B
# dA −= z y^T
# dB += z, where z = inv(A^T) dy
function EnzymeCore.EnzymeRules.augmented_primal(
function EnzymeRules.augmented_primal(
config, func::Const{typeof(LinearSolve.solve!)},
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
Expand Down Expand Up @@ -184,10 +196,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
cachesolve = deepcopy(linsolve.val)

cache = (copy(res.u), resvals, cachesolve, dAs, dbs)
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
return EnzymeRules.AugmentedReturn(res, dres, cache)
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
function EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP};
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
y, dys, _linsolve, dAs, dbs = cache
Expand Down
43 changes: 17 additions & 26 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using Enzyme, ForwardDiff
using LinearSolve, LinearAlgebra, Test
using FiniteDiff
using SafeTestsets

n = 4
A = rand(n, n);
Expand Down Expand Up @@ -178,47 +177,39 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
for alg in (

function fnice(A, b, alg)
prob = LinearProblem(A, b)
sol1 = solve(prob, alg)
return sum(sol1.u)
end

@testset for alg in (
LUFactorization(),
RFLUFactorization() # KrylovJL_GMRES(), fails
)
@show alg
function fb(b)
prob = LinearProblem(A, b)

sol1 = solve(prob, alg)
fb_closure = b -> fnice(A, b, alg)

sum(sol1.u)
end
fb(b1)

fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
fd_jac = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec
@show fd_jac

en_jac = map(onehot(b1)) do db1
eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1))
Copy link
Member Author

@avik-pal avik-pal Sep 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wsmoses with runtime activity I get zeros as the jacobian (the incorrect result was before a patch).
Also I refactored it to not have any closure but still we are seeing the

ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Mismatched activity for:   store {} addrspace(10)* %1, {} addrspace(10)** %.fca.1.gep4, align 8, !dbg !43, !noalias !32 const val: {} addrspace(10)* %1
 value=Unknown object of type Vector{Float64}
 llvalue={} addrspace(10)* %1

Stacktrace:
 [1] #solve#7
   @ /mnt/software/sciml/LinearSolve.jl/src/common.jl:270
 [2] solve
   @ /mnt/software/sciml/LinearSolve.jl/src/common.jl:268
 [3] fnice
   @ ./REPL[111]:3
 [4] fnice
   @ ./REPL[111]:0

eres[1]
return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice,
Const(A), Duplicated(b1, db1), Const(alg)))
end |> collect
@show en_jac

@test en_jac≈fd_jac rtol=1e-4

function fA(A)
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)
fA_closure = A -> fnice(A, b1, alg)

sum(sol1.u)
end
fA(A)

fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
fd_jac = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec
@show fd_jac

en_jac = map(onehot(A)) do dA
eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA))
eres[1]
end |> collect
return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice,
Duplicated(A, dA), Const(b1), Const(alg)))
end |> collect |> (x -> reshape(x, n, n))
@show en_jac

@test en_jac≈fd_jac rtol=1e-4
Expand Down
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@ if GROUP == "All" || GROUP == "Core"
@time @safetestset "Non-Square Tests" include("nonsquare.jl")
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
@time @safetestset "Default Alg Tests" include("default_algs.jl")
@time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
@time @safetestset "Adjoint Sensitivity" include("adjoint.jl")
@time @safetestset "Traits" include("traits.jl")
@time @safetestset "BandedMatrices" include("banded.jl")
@time @safetestset "Static Arrays" include("static_arrays.jl")
end

if GROUP == "All" || GROUP == "Enzyme"
@time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
end

if GROUP == "LinearSolveCUDA"
Pkg.activate("gpu")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Expand Down
Loading