diff --git a/Project.toml b/Project.toml index 7082957..be2e655 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.1.0" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/src/DiffKrylov.jl b/src/DiffKrylov.jl index 9ff750a..e3032d2 100644 --- a/src/DiffKrylov.jl +++ b/src/DiffKrylov.jl @@ -3,6 +3,7 @@ module DiffKrylov using Krylov using SparseArrays using LinearAlgebra +using IncompleteLU include("ForwardDiff/forwarddiff.jl") include("EnzymeRules/enzymerules.jl") end diff --git a/src/EnzymeRules/enzymerules.jl b/src/EnzymeRules/enzymerules.jl index 3a589c4..5dc2545 100644 --- a/src/EnzymeRules/enzymerules.jl +++ b/src/EnzymeRules/enzymerules.jl @@ -12,9 +12,9 @@ for AMT in (:Matrix, :SparseMatrixCSC) ret::Type{RT}, _A::Annotation{MT}, _b::Annotation{VT}; - verbose = 0, M = I, N = I, + verbose = 0, options... ) where {RT, MT <: $AMT, VT <: Vector} psolver = $solver @@ -155,7 +155,19 @@ for AMT in (:Matrix, :SparseMatrixCSC) if verbose > 0 @info "($psolver, $pamt) reverse" end - _b.dval .= Krylov.$solver(transpose(A), bx[]; M=M, N=N, verbose=verbose, options...)[1] + if M == I + nothing + elseif isa(M, IncompleteLU.ILUFactorization) + U = copy(M.U) + L = copy(M.L) + transpose!(U, M.L) + transpose!(L, M.U) + N = IncompleteLU.ILUFactorization(L, U) + M = I + else + error("Preconditioner not supported") + end + _b.dval .= Krylov.$solver(adjoint(A), bx[]; M=M, N=N, verbose=verbose, options...)[1] _A.dval .= -x .* _b.dval' return (nothing, nothing) end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..91ff4b0 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,10 @@ +[deps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895" +Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/create_matrix.jl b/test/create_matrix.jl new file mode 100644 index 0000000..9592685 --- /dev/null +++ b/test/create_matrix.jl @@ -0,0 +1,21 @@ +function create_unsymmetric_matrix(n) + # Ensure the size is at least 2 + if n < 2 + throw(ArgumentError("Matrix size should be at least 2x2")) + end + + # Generate a random n x n matrix with entries from a normal distribution + A = randn(n, n) + + # Perform Singular Value Decomposition + U, S, V = svd(A) + + # Modify the singular values to make them close to each other but not too small + # Here we set them all to be between 1 and 2 + S = Diagonal(range(1, stop=2, length=n)) + + # Reconstruct the matrix + well_conditioned_matrix = U * S * V' + + return well_conditioned_matrix +end diff --git a/test/enzymediff.jl b/test/enzymediff.jl index 6fd0784..a97e12f 100644 --- a/test/enzymediff.jl +++ b/test/enzymediff.jl @@ -1,64 +1,36 @@ using Enzyme import .EnzymeRules: forward, reverse, augmented_primal using .EnzymeRules +using DiffKrylov +using LinearAlgebra +using FiniteDifferences +using Krylov +using Random +using SparseArrays +using Test -@testset "$solver" for solver = (Krylov.cg, Krylov.gmres, Krylov.bicgstab) +Random.seed!(1) +include("create_matrix.jl") +@testset "Enzyme Rules" begin @testset "$MT" for MT = (Matrix, SparseMatrixCSC) - A, b = sparse_laplacian(4, FC=Float64) - A = MT(A) - fdm = central_fdm(8, 1); - function A_one_one(x) - _A = copy(A) - _A[1,1] = x - solver(_A,b) + @testset "($M, $N)" for (M,N) = ((I,I),) + # Square unsymmetric solvers + @testset "$solver" for solver = (Krylov.gmres, Krylov.bicgstab) + A = [] + if MT == Matrix + A = create_unsymmetric_matrix(10) + b = rand(10) + else + A, b = sparse_laplacian(4, FC=Float64) + end + test_enzyme_with(solver, A, b, M, N) + end + # Square symmetric solvers + @testset "$solver" for solver = (Krylov.cg,) + A, b = sparse_laplacian(4, FC=Float64) + A = MT(A) + test_enzyme_with(solver, A, b, M, N) + end end - - function b_one(x) - _b = copy(b) - _b[1] = x - solver(A,_b) - end - - fda = FiniteDifferences.jacobian(fdm, a -> A_one_one(a)[1], copy(A[1,1])) - fdb = FiniteDifferences.jacobian(fdm, a -> b_one(a)[1], copy(b[1])) - fd =fda[1] + fdb[1] - # Test forward - function duplicate(A::SparseMatrixCSC) - dA = copy(A) - fill!(dA.nzval, zero(eltype(A))) - return dA - end - duplicate(A::Matrix) = zeros(size(A)) - - dA = Duplicated(A, duplicate(A)) - db = Duplicated(b, zeros(length(b))) - dA.dval[1,1] = 1.0 - db.dval[1] = 1.0 - dx = Enzyme.autodiff( - Forward, - solver, - dA, - db - ) - @test isapprox(dx[1][1], fd, atol=1e-4, rtol=1e-4) - # Test reverse - function driver!(x, A, b) - x .= gmres(A,b)[1] - nothing - end - dA = Duplicated(A, duplicate(A)) - db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) - dx.dval[1] = 1.0 - Enzyme.autodiff( - Reverse, - driver!, - dx, - dA, - db - ) - - @test isapprox(db.dval[1], fdb[1][1], atol=1e-4, rtol=1e-4) - @test isapprox(dA.dval[1,1], fda[1][1], atol=1e-4, rtol=1e-4) end end diff --git a/test/runtests.jl b/test/runtests.jl index 74241b4..ecfda33 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,16 +7,8 @@ using ForwardDiff import ForwardDiff: Dual, Partials, partials, value using FiniteDifferences -include("get_div_grad.jl") include("utils.jl") -# Sparse Laplacian. -function sparse_laplacian(n :: Int=16; FC=Float64) - A = get_div_grad(n, n, n) - b = ones(n^3) - return A, b -end - atol = 1e-12 rtol = 0.0 @testset "DiffKrylov" begin diff --git a/test/utils.jl b/test/utils.jl index da621ac..3ca1887 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,3 +1,11 @@ +# Sparse Laplacian. +include("get_div_grad.jl") +function sparse_laplacian(n :: Int=16; FC=Float64) + A = get_div_grad(n, n, n) + b = ones(n^3) + return A, b +end + function check(A,b) tA, tb = sparse_laplacian(4, FC=Float64) @test all(value.(tb) .== b) @@ -75,3 +83,97 @@ function check_derivatives_and_values_active_passive(solver, A, b, x) isapprox(value.(dx), x) @test isapprox(partials.(dx,1), fda[1]) end + +struct GMRES end +struct BICGSTAB end +struct CG end + +function driver!(::GMRES, x, A, b, M, N, ldiv=false) + x .= gmres(A,b, atol=1e-16, rtol=1e-16, M=M, N=N, verbose=0, ldiv=ldiv)[1] + nothing +end + +function driver!(::BICGSTAB, x, A, b, M, N, ldiv=false) + x .= bicgstab(A,b, atol=1e-16, rtol=1e-16, M=M, N=N, verbose=0, ldiv=ldiv)[1] + nothing +end + +function driver!(::CG, x, A, b, M, N, ldiv=false) + x .= cg(A,b, atol=1e-16, rtol=1e-16, M=M, verbose=0, ldiv=ldiv)[1] + nothing +end + +function test_enzyme_with(solver, A, b, M, N, ldiv=false) + tsolver = if solver == Krylov.cg + CG() + elseif solver == Krylov.gmres + GMRES() + elseif solver == Krylov.bicgstab + BICGSTAB() + else + error("Unsupported solver $solver is tested in DiffKrylov.jl") + end + fdm = central_fdm(8, 1); + function A_one_one(hx) + _A = copy(A) + _A[1,1] = hx + x = zeros(length(b)) + driver!(tsolver, x, _A, b, M, N, ldiv) + return x + end + + function b_one(hx) + _b = copy(b) + _b[1] = hx + x = zeros(length(b)) + driver!(tsolver, x, A, _b, M, N, ldiv) + return x + end + + fda = FiniteDifferences.jacobian(fdm, a -> A_one_one(a), copy(A[1,1])) + fdb = FiniteDifferences.jacobian(fdm, a -> b_one(a), copy(b[1])) + fd =fda[1] + fdb[1] + # Test forward + function duplicate(A::SparseMatrixCSC) + dA = copy(A) + fill!(dA.nzval, zero(eltype(A))) + return dA + end + duplicate(A::Matrix) = zeros(size(A)) + + dA = Duplicated(A, duplicate(A)) + db = Duplicated(b, zeros(length(b))) + dx = Duplicated(zeros(length(b)), zeros(length(b))) + dA.dval[1,1] = 1.0 + db.dval[1] = 1.0 + Enzyme.autodiff( + Forward, + driver!, + Const(tsolver), + dx, + dA, + db, + Const(M), + Const(N), + Const(ldiv) + ) + @test isapprox(dx.dval, fd, atol=1e-4, rtol=1e-4) + # Test reverse + dA = Duplicated(A, duplicate(A)) + db = Duplicated(b, zeros(length(b))) + dx = Duplicated(zeros(length(b)), zeros(length(b))) + dx.dval[1] = 1.0 + Enzyme.autodiff( + Reverse, + driver!, + Const(tsolver), + dx, + dA, + db, + Const(M), + Const(N), + Const(ldiv) + ) + @test isapprox(db.dval[1], fdb[1][1], atol=1e-4, rtol=1e-4) + @test isapprox(dA.dval[1,1], fda[1][1], atol=1e-4, rtol=1e-4) +end