Skip to content

Commit

Permalink
Left ILU support (#6)
Browse files Browse the repository at this point in the history
* Deal with ILU

* Refactor tests. Fix ILU
  • Loading branch information
michel2323 authored Apr 29, 2024
1 parent 2d1a718 commit fc34100
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 66 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.1.0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
1 change: 1 addition & 0 deletions src/DiffKrylov.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module DiffKrylov
using Krylov
using SparseArrays
using LinearAlgebra
using IncompleteLU
include("ForwardDiff/forwarddiff.jl")
include("EnzymeRules/enzymerules.jl")
end
16 changes: 14 additions & 2 deletions src/EnzymeRules/enzymerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ for AMT in (:Matrix, :SparseMatrixCSC)
ret::Type{RT},
_A::Annotation{MT},
_b::Annotation{VT};
verbose = 0,
M = I,
N = I,
verbose = 0,
options...
) where {RT, MT <: $AMT, VT <: Vector}
psolver = $solver
Expand Down Expand Up @@ -155,7 +155,19 @@ for AMT in (:Matrix, :SparseMatrixCSC)
if verbose > 0
@info "($psolver, $pamt) reverse"
end
_b.dval .= Krylov.$solver(transpose(A), bx[]; M=M, N=N, verbose=verbose, options...)[1]
if M == I
nothing
elseif isa(M, IncompleteLU.ILUFactorization)
U = copy(M.U)
L = copy(M.L)
transpose!(U, M.L)
transpose!(L, M.U)
N = IncompleteLU.ILUFactorization(L, U)
M = I
else
error("Preconditioner not supported")
end
_b.dval .= Krylov.$solver(adjoint(A), bx[]; M=M, N=N, verbose=verbose, options...)[1]
_A.dval .= -x .* _b.dval'
return (nothing, nothing)
end
Expand Down
10 changes: 10 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[deps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
21 changes: 21 additions & 0 deletions test/create_matrix.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
function create_unsymmetric_matrix(n)
# Ensure the size is at least 2
if n < 2
throw(ArgumentError("Matrix size should be at least 2x2"))
end

# Generate a random n x n matrix with entries from a normal distribution
A = randn(n, n)

# Perform Singular Value Decomposition
U, S, V = svd(A)

# Modify the singular values to make them close to each other but not too small
# Here we set them all to be between 1 and 2
S = Diagonal(range(1, stop=2, length=n))

# Reconstruct the matrix
well_conditioned_matrix = U * S * V'

return well_conditioned_matrix
end
84 changes: 28 additions & 56 deletions test/enzymediff.jl
Original file line number Diff line number Diff line change
@@ -1,64 +1,36 @@
using Enzyme
import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules
using DiffKrylov
using LinearAlgebra
using FiniteDifferences
using Krylov
using Random
using SparseArrays
using Test

@testset "$solver" for solver = (Krylov.cg, Krylov.gmres, Krylov.bicgstab)
Random.seed!(1)
include("create_matrix.jl")
@testset "Enzyme Rules" begin
@testset "$MT" for MT = (Matrix, SparseMatrixCSC)
A, b = sparse_laplacian(4, FC=Float64)
A = MT(A)
fdm = central_fdm(8, 1);
function A_one_one(x)
_A = copy(A)
_A[1,1] = x
solver(_A,b)
@testset "($M, $N)" for (M,N) = ((I,I),)
# Square unsymmetric solvers
@testset "$solver" for solver = (Krylov.gmres, Krylov.bicgstab)
A = []
if MT == Matrix
A = create_unsymmetric_matrix(10)
b = rand(10)
else
A, b = sparse_laplacian(4, FC=Float64)
end
test_enzyme_with(solver, A, b, M, N)
end
# Square symmetric solvers
@testset "$solver" for solver = (Krylov.cg,)
A, b = sparse_laplacian(4, FC=Float64)
A = MT(A)
test_enzyme_with(solver, A, b, M, N)
end
end

function b_one(x)
_b = copy(b)
_b[1] = x
solver(A,_b)
end

fda = FiniteDifferences.jacobian(fdm, a -> A_one_one(a)[1], copy(A[1,1]))
fdb = FiniteDifferences.jacobian(fdm, a -> b_one(a)[1], copy(b[1]))
fd =fda[1] + fdb[1]
# Test forward
function duplicate(A::SparseMatrixCSC)
dA = copy(A)
fill!(dA.nzval, zero(eltype(A)))
return dA
end
duplicate(A::Matrix) = zeros(size(A))

dA = Duplicated(A, duplicate(A))
db = Duplicated(b, zeros(length(b)))
dA.dval[1,1] = 1.0
db.dval[1] = 1.0
dx = Enzyme.autodiff(
Forward,
solver,
dA,
db
)
@test isapprox(dx[1][1], fd, atol=1e-4, rtol=1e-4)
# Test reverse
function driver!(x, A, b)
x .= gmres(A,b)[1]
nothing
end
dA = Duplicated(A, duplicate(A))
db = Duplicated(b, zeros(length(b)))
dx = Duplicated(zeros(length(b)), zeros(length(b)))
dx.dval[1] = 1.0
Enzyme.autodiff(
Reverse,
driver!,
dx,
dA,
db
)

@test isapprox(db.dval[1], fdb[1][1], atol=1e-4, rtol=1e-4)
@test isapprox(dA.dval[1,1], fda[1][1], atol=1e-4, rtol=1e-4)
end
end
8 changes: 0 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,8 @@ using ForwardDiff
import ForwardDiff: Dual, Partials, partials, value
using FiniteDifferences

include("get_div_grad.jl")
include("utils.jl")

# Sparse Laplacian.
function sparse_laplacian(n :: Int=16; FC=Float64)
A = get_div_grad(n, n, n)
b = ones(n^3)
return A, b
end

atol = 1e-12
rtol = 0.0
@testset "DiffKrylov" begin
Expand Down
102 changes: 102 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# Sparse Laplacian.
include("get_div_grad.jl")
function sparse_laplacian(n :: Int=16; FC=Float64)
A = get_div_grad(n, n, n)
b = ones(n^3)
return A, b
end

function check(A,b)
tA, tb = sparse_laplacian(4, FC=Float64)
@test all(value.(tb) .== b)
Expand Down Expand Up @@ -75,3 +83,97 @@ function check_derivatives_and_values_active_passive(solver, A, b, x)
isapprox(value.(dx), x)
@test isapprox(partials.(dx,1), fda[1])
end

struct GMRES end
struct BICGSTAB end
struct CG end

function driver!(::GMRES, x, A, b, M, N, ldiv=false)
x .= gmres(A,b, atol=1e-16, rtol=1e-16, M=M, N=N, verbose=0, ldiv=ldiv)[1]
nothing
end

function driver!(::BICGSTAB, x, A, b, M, N, ldiv=false)
x .= bicgstab(A,b, atol=1e-16, rtol=1e-16, M=M, N=N, verbose=0, ldiv=ldiv)[1]
nothing
end

function driver!(::CG, x, A, b, M, N, ldiv=false)
x .= cg(A,b, atol=1e-16, rtol=1e-16, M=M, verbose=0, ldiv=ldiv)[1]
nothing
end

function test_enzyme_with(solver, A, b, M, N, ldiv=false)
tsolver = if solver == Krylov.cg
CG()
elseif solver == Krylov.gmres
GMRES()
elseif solver == Krylov.bicgstab
BICGSTAB()
else
error("Unsupported solver $solver is tested in DiffKrylov.jl")
end
fdm = central_fdm(8, 1);
function A_one_one(hx)
_A = copy(A)
_A[1,1] = hx
x = zeros(length(b))
driver!(tsolver, x, _A, b, M, N, ldiv)
return x
end

function b_one(hx)
_b = copy(b)
_b[1] = hx
x = zeros(length(b))
driver!(tsolver, x, A, _b, M, N, ldiv)
return x
end

fda = FiniteDifferences.jacobian(fdm, a -> A_one_one(a), copy(A[1,1]))
fdb = FiniteDifferences.jacobian(fdm, a -> b_one(a), copy(b[1]))
fd =fda[1] + fdb[1]
# Test forward
function duplicate(A::SparseMatrixCSC)
dA = copy(A)
fill!(dA.nzval, zero(eltype(A)))
return dA
end
duplicate(A::Matrix) = zeros(size(A))

dA = Duplicated(A, duplicate(A))
db = Duplicated(b, zeros(length(b)))
dx = Duplicated(zeros(length(b)), zeros(length(b)))
dA.dval[1,1] = 1.0
db.dval[1] = 1.0
Enzyme.autodiff(
Forward,
driver!,
Const(tsolver),
dx,
dA,
db,
Const(M),
Const(N),
Const(ldiv)
)
@test isapprox(dx.dval, fd, atol=1e-4, rtol=1e-4)
# Test reverse
dA = Duplicated(A, duplicate(A))
db = Duplicated(b, zeros(length(b)))
dx = Duplicated(zeros(length(b)), zeros(length(b)))
dx.dval[1] = 1.0
Enzyme.autodiff(
Reverse,
driver!,
Const(tsolver),
dx,
dA,
db,
Const(M),
Const(N),
Const(ldiv)
)
@test isapprox(db.dval[1], fdb[1][1], atol=1e-4, rtol=1e-4)
@test isapprox(dA.dval[1,1], fda[1][1], atol=1e-4, rtol=1e-4)
end

0 comments on commit fc34100

Please sign in to comment.