Skip to content

Commit

Permalink
Full ForwardDiff.jl support
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Dec 8, 2023
1 parent 83f7d17 commit 75ddf8b
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 183 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# DiffKrylov

DiffKrylov provides a differentiable API for [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) using [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). In the future, we will also support Enzyme and its reverse mode.
154 changes: 88 additions & 66 deletions src/ForwardDiff/krylov.jl
Original file line number Diff line number Diff line change
@@ -1,77 +1,99 @@
import ForwardDiff: Dual, Partials, partials, value

for solver in (:cg, :gmres, :bicgstab)
@eval begin
function Krylov.$solver(_A::SparseMatrixCSC{V, Int64}, _b::Vector{Dual{T, V, N}}; options...) where {T, V, N}
A = SparseMatrixCSC(_A.m, _A.n, _A.colptr, _A.rowval, value.(_A.nzval))
b = value.(_b)
m = length(b)
dbs = Matrix{V}(undef, m, N)
for i in 1:m
dbs[i,:] = partials(_b[i])
end
x, stats = $solver(A,b; options...)
dxs = Matrix{Float64}(undef, m, N)
px = Vector{Partials{N,V}}(undef, m)
_matrix_values(A::SparseMatrixCSC{Dual{T, V, N}, IT}) where {T, V, N, IT} = SparseMatrixCSC(A.m, A.n, A.colptr, A.rowval, value.(A.nzval))
_matrix_values(A::Matrix{Dual{T, V, N}}) where {T, V, N} = Matrix{V}(value.(A))
function _matrix_partials(A::SparseMatrixCSC{Dual{T, V, N}, IT}) where {T, V, N, IT}
dAs = Vector{SparseMatrixCSC{Float64, Int64}}(undef, N)
for i in 1:N
nb = dbs[:,i]
dx, dstats = $solver(A,nb; options...)
dxs[:,i] = dx
dAs[i] = SparseMatrixCSC(A.m, A.n, A.colptr, A.rowval, partials.(A.nzval, i))
end
for i in 1:m
px[i] = Partials{N,V}(Tuple(dxs[i,j] for j in 1:N))
end
duals = Dual{T,V,N}.(x, px)
return (duals, stats)
return dAs
end
function _matrix_partials(A::Matrix{Dual{T, V, N}}) where {T, V, N}
dAs = Vector{Matrix{V}}(undef, N)
for i in 1:N
dAs[i] = Matrix(partials.(A, i))
end
return dAs
end

function Krylov.$solver(A::Matrix{V}, _b::Vector{Dual{T, V, N}}; options...) where {T, V, N}
b = value.(_b)
m = length(b)
dbs = Matrix{V}(undef, m, N)
for i in 1:m
dbs[i,:] = partials(_b[i])
end
x, stats = $solver(A,b; options...)
dxs = Matrix{Float64}(undef, m, N)
px = Vector{Partials{N,V}}(undef, m)
for i in 1:N
nb = dbs[:,i]
dx, dstats = $solver(A,nb; options...)
dxs[:,i] = dx

for solver in (:cg, :gmres, :bicgstab)
for matrix in (:(SparseMatrixCSC{V, IT}), :(Matrix{V}))
@eval begin
function Krylov.$solver(A::$matrix, _b::Vector{Dual{T, V, N}}; options...) where {T, V, N, IT}
b = value.(_b)
m = length(b)
dbs = Matrix{V}(undef, m, N)
for i in 1:m
dbs[i,:] = partials(_b[i])
end
x, stats = $solver(A,b; options...)
dxs = Matrix{V}(undef, m, N)
px = Vector{Partials{N,V}}(undef, m)
for i in 1:N
nb = dbs[:,i]
dx, dstats = $solver(A,nb; options...)
dxs[:,i] = dx
end
for i in 1:m
px[i] = Partials{N,V}(Tuple(dxs[i,j] for j in 1:N))
end
duals = Dual{T,V,N}.(x, px)
return (duals, stats)
end
end
end
for i in 1:m
px[i] = Partials{N,V}(Tuple(dxs[i,j] for j in 1:N))

for matrix in (:(SparseMatrixCSC{Dual{T,V,N}, IT}), :(Matrix{Dual{T,V,N}}))
@eval begin
function Krylov.$solver(_A::$matrix, b::Vector{V}; options...) where {T, V, N, IT}
A = _matrix_values(_A)
dAs = _matrix_partials(_A)
m = length(b)
x, stats = $solver(A,b)
dxs = Matrix{Float64}(undef, m, N)
px = Vector{Partials{N,V}}(undef, m)
for i in 1:N
nb = - dAs[i]*x
dx, dstats = $solver(A,nb)
dxs[:,i] = dx
end
for i in 1:m
px[i] = Partials{N,V}(Tuple(dxs[i,j] for j in 1:N))
end
duals = Dual{T,V,N}.(x, px)
return (duals, stats)
end
end
end
duals = Dual{T,V,N}.(x, px)
return (duals, stats)

for matrix in (:(SparseMatrixCSC{Dual{T,V,N}, IT}), :(Matrix{Dual{T,V,N}}))
@eval begin
function Krylov.$solver(_A::$matrix, _b::Vector{Dual{T, V, N}}; options...) where {T, V, N, IT}
A = _matrix_values(_A)
dAs = _matrix_partials(_A)
b = value.(_b)
m = length(b)
dbs = Matrix{V}(undef, m, N)
for i in 1:m
dbs[i,:] = partials(_b[i])
end
x, stats = $solver(A,b)
dxs = Matrix{Float64}(undef, m, N)
px = Vector{Partials{N,V}}(undef, m)
for i in 1:N
nb = dbs[:,i] - dAs[i]*x
dx, dstats = $solver(A,nb)
dxs[:,i] = dx
end
for i in 1:m
px[i] = Partials{N,V}(Tuple(dxs[i,j] for j in 1:N))
end
duals = Dual{T,V,N}.(x, px)
return (duals, stats)
end
end
end
end

# function Krylov.cg(_A::SparseMatrixCSC{Dual{T, V, NA}, Int64}, _b::Vector{Dual{T, V, NB}}; options...) where {T, V, NA, NB}
# A = SparseMatrixCSC(_A.m, _A.n, _A.colptr, _A.rowval, value.(_A.nzval))
# dAs = Vector{SparseMatrixCSC{Float64, Int64}}(undef, NA)
# for i in 1:NA
# dAs[i] = SparseMatrixCSC(_A.m, _A.n, _A.colptr, _A.rowval, partials.(_A.nzval, i))
# end
# b = value.(_b)
# m = length(b)
# dbs = Matrix{V}(undef, m, NB)
# for i in 1:m
# dbs[i,:] = partials(_b[i])
# end
# x, stats = cg(A,b)
# dxs = Matrix{Float64}(undef, m, N)
# px = Vector{Partials{N,V}}(undef, n)
# for i in 1:N
# nb = dbs[:,i] - dAs[i]*x
# dx, dstats = cg(A[i],nb)
# dxs[:,i] = dx
# end
# for i in 1:m
# px[i] = Partials{N,V}(Tuple(dxs[i,j] for j in 1:N))
# end
# duals = Dual{T,V,N}.(x, px)
# return (duals, stats)
# end
end
117 changes: 0 additions & 117 deletions test/autodiff.jl

This file was deleted.

46 changes: 46 additions & 0 deletions test/forwarddiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Sparse Laplacian.
include("get_div_grad.jl")
include("utils.jl")
solver = Krylov.cg
function sparse_laplacian(n :: Int=16; FC=Float64)
A = get_div_grad(n, n, n)
b = ones(n^3)
return A, b
end
A, b = sparse_laplacian(4, FC=Float64)
@testset "$solver" for solver = (Krylov.cg, Krylov.gmres, Krylov.bicgstab)
x, stats = cg(A,b)

# A passive, b active
# Sparse
@testset "A sparse passive, b active" begin
check_jacobian(solver, A, b)
check_values(solver, A, b)
end
# Dense
@testset "A dense passive, b active" begin
denseA = Matrix(A)
check_jacobian(solver, denseA, b)
check_values(solver, denseA, b)
end

# A active, b active
# Sparse
@testset "A sparse active, b active" begin
check_derivatives_and_values_active_active(solver, A, b, x)
end
# Dense
@testset "A dense active, b active" begin
check_derivatives_and_values_active_active(solver, Matrix(A), b, x)
end

# A active, b passive
# Sparse
@testset "A sparse active, b active" begin
check_derivatives_and_values_active_passive(solver, A, b, x)
end
# Dense
@testset "A dense active, b active" begin
check_derivatives_and_values_active_passive(solver, Matrix(A), b, x)
end
end
14 changes: 14 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using Krylov
using DiffKrylov
using Test
using LinearAlgebra
using SparseArrays
using ForwardDiff
import ForwardDiff: Dual, Partials, partials, value
using FiniteDifferences

@testset "DiffKrylov" begin
@testset "ForwardDiff" begin
include("forwarddiff.jl")
end
end
Loading

0 comments on commit 75ddf8b

Please sign in to comment.