From 8dd0a06d33bb35d4336e7b71efe2a0da55578ab1 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Wed, 30 Jan 2019 10:59:43 -0800 Subject: [PATCH] Allow constructing Cholesky objects directly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This allows recording `Cholesky` calls in `Branch`es and using them with `∇`. Since no computation is involved when constructing a `Cholesky` object (it's assumed on construction that the matrix passed to the constructor is the factorized matrix), it just returns the input. --- .../linalg/factorization/cholesky.jl | 22 +++++++++++++++++- .../linalg/factorization/cholesky.jl | 23 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/src/sensitivities/linalg/factorization/cholesky.jl b/src/sensitivities/linalg/factorization/cholesky.jl index 9eb04be4..5cad9ac5 100644 --- a/src/sensitivities/linalg/factorization/cholesky.jl +++ b/src/sensitivities/linalg/factorization/cholesky.jl @@ -1,5 +1,5 @@ import LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger! -import LinearAlgebra: cholesky +import LinearAlgebra: cholesky, Cholesky import Base: getproperty Base.@deprecate chol(X) cholesky(X).U @@ -31,6 +31,26 @@ function ∇(::typeof(getproperty), ::Type{Arg{1}}, p, y, ȳ, C::Cholesky, x::S end end +@explicit_intercepts( + Cholesky, + Tuple{AbstractMatrix{<:∇Scalar}, Union{Char, Symbol}, Integer}, + [true, false, false], +) +function ∇( + ::Type{Cholesky}, + ::Type{Arg{1}}, + p, + C::Cholesky, + X̄::Union{UpperTriangular, LowerTriangular}, + X::Union{UpperTriangular, LowerTriangular}, + uplo::Union{Char, Symbol}, + info::Integer, +) + # We aren't doing any actual computation if we've constructed a Cholesky object + # directly, so just pass through this call and return the sensitivies + return X̄ +end + """ level2partition(A::AbstractMatrix, j::Int, upper::Bool) diff --git a/test/sensitivities/linalg/factorization/cholesky.jl b/test/sensitivities/linalg/factorization/cholesky.jl index fe17713d..4703a0c5 100644 --- a/test/sensitivities/linalg/factorization/cholesky.jl +++ b/test/sensitivities/linalg/factorization/cholesky.jl @@ -65,4 +65,27 @@ @test_throws ArgumentError ∇(X->cholesky(X).info)(X_) end + + let + X_ = Matrix{Float64}(I, 5, 5) + X = Leaf(Tape(), X_) + U = cholesky(X).U + C = Cholesky(U, 'U', 0) + @test C isa Branch{<:Cholesky} + @test getfield(C, :f) == LinearAlgebra.Cholesky + @test unbox(C) == Cholesky(UpperTriangular(X_), 'U', 0) + # Ensure we can still directly extract the .U field + UU = C.U + @test UU isa Branch{<:UpperTriangular} + # And access .L as well + LL = C.L + @test LL isa Branch{<:LowerTriangular} + # Make sure that computing the Cholesky and already having the Cholesky + # produce the same results + expected = Matrix(0.5I, 5, 5) + @test ∇(X->det(cholesky(X).U))(X_)[1] ≈ expected + @test ∇(X->det(cholesky(X).L))(X_)[1] ≈ expected + @test ∇(X->det(Cholesky(cholesky(X).U, :U, 0).U))(X_)[1] ≈ expected + @test ∇(X->det(Cholesky(cholesky(X).L, 'L', 0).U))(X_)[1] ≈ expected + end end