Skip to content

Commit

Permalink
Merge pull request #355 from SciML/accelerate
Browse files Browse the repository at this point in the history
Support Apple Accelerate and improve MKL integration
  • Loading branch information
ChrisRackauckas authored Aug 8, 2023
2 parents 464156c + 6d5aeb4 commit aabe2f2
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 5 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Expand Down
12 changes: 8 additions & 4 deletions ext/LinearSolveMKLExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(siz
chkstride1(A)
m, n = size(A)
lda = max(1,stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A,1),size(A,2)))
end
ccall((@blasfunc(dgetrf_), MKL_jll.libmkl_rt), Cvoid,
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[] #Error code is stored in LU factorization type
A, ipiv, info[], info #Error code is stored in LU factorization type
end

default_alias_A(::MKLLUFactorization, ::Any, ::Any) = false
Expand All @@ -30,7 +33,7 @@ default_alias_b(::MKLLUFactorization, ::Any, ::Any) = false
function LinearSolve.init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
ArrayInterface.lu_instance(convert(AbstractMatrix, A)), Ref{BlasInt}()
end

function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
Expand All @@ -39,11 +42,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = @get_cacheval(cache, :MKLLUFactorization)
fact = LU(getrf!(A)...)
res = getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
fact = LU(res[1:3]...), res[4]
cache.cacheval = fact
cache.isfresh = false
end
y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization), cache.b)
y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization)[1], cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

Expand Down
6 changes: 6 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ using EnumX
using Requires
import InteractiveUtils

using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
@blasfunc, chkargsok

import GPUArraysCore
import Preferences

Expand Down Expand Up @@ -87,6 +91,7 @@ include("solve_function.jl")
include("default.jl")
include("init.jl")
include("extension_algs.jl")
include("appleaccelerate.jl")
include("deprecated.jl")

@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;
Expand Down Expand Up @@ -185,6 +190,7 @@ export CudaOffloadFactorization
export MKLPardisoFactorize, MKLPardisoIterate
export PardisoJL
export MKLLUFactorization
export AppleAccelerateLUFactorization

export OperatorAssumptions, OperatorCondition

Expand Down
102 changes: 102 additions & 0 deletions src/appleaccelerate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using LinearAlgebra
using Libdl

# For now, only use BLAS from Accelerate (that is to say, vecLib).
# Fixed framework path on macOS; `appleaccelerate_isavailable` probes this
# path with dlopen before any ccall uses it.
global const libacc = "/System/Library/Frameworks/Accelerate.framework/Accelerate"

"""
```julia
AppleAccelerateLUFactorization()
```
A wrapper over Apple's Accelerate Library. Direct calls to Acceelrate in a way that pre-allocates workspace
to avoid allocations and does not require libblastrampoline.
"""
struct AppleAccelerateLUFactorization <: AbstractFactorization end

"""
    appleaccelerate_isavailable()

Return `true` when the Accelerate framework can be dlopen'ed and exposes the
`dgetrf_` LAPACK entry point used by `AppleAccelerateLUFactorization`.
"""
function appleaccelerate_isavailable()
    # Accelerate only ships on Apple platforms; skip the dlopen probe elsewhere.
    @static if !Sys.isapple()
        return false
    end
    libacc_hdl = Libdl.dlopen_e(libacc)
    if libacc_hdl == C_NULL
        return false
    end
    # Probe for the LAPACK symbol we actually call, then close the probe
    # handle so repeated availability checks do not accumulate dlopen
    # references to the framework.
    found = Libdl.dlsym_e(libacc_hdl, "dgetrf_") != C_NULL
    Libdl.dlclose(libacc_hdl)
    return found
end

"""
    aa_getrf!(A; ipiv, info, check = false)

LU-factorize the `Float64` matrix `A` in place via a direct `ccall` to
Accelerate's `dgetrf_` (no libblastrampoline involved).

- `ipiv`: pivot vector; a cached (possibly empty) vector may be passed in and
  is re-allocated here when empty so later factorizations can reuse it.
- `info`: `Ref{Cint}` that receives the LAPACK status code.
- `check`: when `true`, error early on non-finite entries via `chkfinite`.

Returns `(A, ipiv, BlasInt(info[]), info)`. A negative status (bad argument)
throws `ArgumentError`; a positive status (singular `U`) is NOT thrown here —
the caller stores it in the `LU` object.
"""
function aa_getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, Cint, min(size(A,1),size(A,2))), info = Ref{Cint}(), check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    # Leading dimension as LAPACK expects it (at least 1, even for empty A).
    lda = max(1,stride(A, 2))
    # An empty cached pivot vector (from `init_cacheval`) is grown on the
    # first factorization; subsequent calls then reuse it allocation-free.
    if isempty(ipiv)
        ipiv = similar(A, Cint, min(size(A,1),size(A,2)))
    end

    ccall(("dgetrf_", libacc), Cvoid,
        (Ref{Cint}, Ref{Cint}, Ptr{Float64},
            Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
        m, n, A, lda, ipiv, info)
    info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK dgetrf_"))
    A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type
end

"""
    aa_getrs!(trans, A, ipiv, B; info)

Solve `A * X = B` (or the transposed system, per `trans`) in place via a direct
`ccall` to Accelerate's `dgetrs_`, where `A` and `ipiv` hold the LU factors
produced by [`aa_getrf!`](@ref). `B` is overwritten with the solution and
returned. `info` receives the LAPACK status code; a nonzero code throws via
`chklapackerror`.
"""
function aa_getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64},
        ipiv::AbstractVector{Cint}, B::AbstractVecOrMat{<:Float64};
        info = Ref{Cint}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    # Number of right-hand sides (was computed but unused; the ccall
    # previously recomputed size(B, 2) inline).
    nrhs = size(B, 2)
    # Trailing Clong is the hidden Fortran character-length argument for `trans`.
    ccall(("dgetrs_", libacc), Cvoid,
        (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float64}, Ref{Cint},
            Ptr{Cint}, Ptr{Float64}, Ref{Cint}, Ptr{Cint}, Clong),
        trans, n, nrhs, A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info, 1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end

# Never alias A or b: `aa_getrf!` overwrites the matrix it factorizes in place,
# so the solver must operate on its own copies rather than the user's arrays.
default_alias_A(::AppleAccelerateLUFactorization, ::Any, ::Any) = false
default_alias_b(::AppleAccelerateLUFactorization, ::Any, ::Any) = false

# Seed the cache with an LU skeleton plus a reusable status Ref. The pivot
# vector starts empty so `aa_getrf!` sizes it exactly once, on the first
# factorization; afterwards both are reused allocation-free.
function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    skeleton = ArrayInterface.lu_instance(convert(AbstractMatrix, A))
    return LU(skeleton.factors, similar(A, Cint, 0), skeleton.info), Ref{Cint}()
end

# Solve the cached linear system with Accelerate's LU routines. Refactorizes
# only when `cache.isfresh`, reusing the cached pivot vector and info Ref so
# repeated solves do not allocate.
#
# Fix: the overdetermined branch (`m > n`) previously `return`ed the raw
# `copyto!` result (a plain vector) instead of a `SciMLBase` linear solution,
# making the return type inconsistent with the square/underdetermined branch.
# Both branches now fall through to `build_linear_solution`.
function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorization;
    kwargs...)
    A = cache.A
    A = convert(AbstractMatrix, A)
    if cache.isfresh
        cacheval = @get_cacheval(cache, :AppleAccelerateLUFactorization)
        # Reuse cached ipiv/info so refactorization is allocation-free.
        res = aa_getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
        fact = LU(res[1:3]...), res[4]
        cache.cacheval = fact
        cache.isfresh = false
    end

    F, info = @get_cacheval(cache, :AppleAccelerateLUFactorization)
    LinearAlgebra.require_one_based_indexing(cache.u, cache.b)
    m, n = size(F, 1), size(F, 2)
    if m > n
        # Overdetermined: solve into a copy of b, then keep the first n entries.
        Bc = copy(cache.b)
        aa_getrs!('N', F.factors, F.ipiv, Bc; info)
        copyto!(cache.u, 1, Bc, 1, n)
    else
        copyto!(cache.u, cache.b)
        aa_getrs!('N', F.factors, F.ipiv, cache.u; info)
    end

    SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end
3 changes: 3 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ end
test_interface(alg, prob1, prob2)
end
end
if LinearSolve.appleaccelerate_isavailable()
test_interface(AppleAccelerateLUFactorization(), prob1, prob2)
end
end

@testset "Generic Factorizations" begin
Expand Down
4 changes: 3 additions & 1 deletion test/resolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ using LinearSolve, LinearAlgebra, SparseArrays, InteractiveUtils, Test

for alg in subtypes(LinearSolve.AbstractFactorization)
@show alg
if !(alg in [DiagonalFactorization, CudaOffloadFactorization])
if !(alg in [DiagonalFactorization, CudaOffloadFactorization, AppleAccelerateLUFactorization]) &&
(!(alg == AppleAccelerateLUFactorization) || LinearSolve.appleaccelerate_isavailable())

A = [1.0 2.0; 3.0 4.0]
alg in [KLUFactorization, UMFPACKFactorization, SparspakFactorization] &&
(A = sparse(A))
Expand Down

0 comments on commit aabe2f2

Please sign in to comment.