Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Commit

Permalink
Implement sensitivities for SVD (#131)
Browse files Browse the repository at this point in the history
This supports `svd` and retrieving the `U`, `S`, or `V` property of the
resulting `SVD` object.
  • Loading branch information
ararslan authored Feb 18, 2019
1 parent 689cff3 commit 2026960
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/Nabla.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,6 @@ module Nabla
include("sensitivities/linalg/diagonal.jl")
include("sensitivities/linalg/triangular.jl")
include("sensitivities/linalg/factorization/cholesky.jl")
include("sensitivities/linalg/factorization/svd.jl")

end # module Nabla
117 changes: 117 additions & 0 deletions src/sensitivities/linalg/factorization/svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import LinearAlgebra: svd
import Base: getproperty

# Register `svd` on real matrices with Nabla's interception machinery.
# NOTE(review): the exact semantics of `@explicit_intercepts` are defined elsewhere in the
# package — confirm there.
@explicit_intercepts svd Tuple{AbstractMatrix{<:Real}}

# Reverse-mode sensitivities of `svd` with respect to its (only) matrix argument `A`.
#
# The incoming sensitivity belongs to exactly one factor of the `SVD` object, and which
# factor it is gets decided purely by dispatch on the sensitivity's type:
#   * an `AbstractVector` is `S̄`, the sensitivity of the singular values;
#   * an `Adjoint` is `V̄` (the `getproperty` sensitivity for `:V` deliberately returns an
#     `Adjoint` so that this method is selected);
#   * any other `AbstractMatrix` is `Ū`.
# The sensitivities of the two remaining factors are filled with zeros before everything
# is handed to `svd_rev`.
∇(::typeof(svd), ::Type{Arg{1}}, p, USV::SVD, S̄::AbstractVector, A::AbstractMatrix) =
    svd_rev(USV, zeroslike(USV.U), S̄, zeroslike(USV.V))
∇(::typeof(svd), ::Type{Arg{1}}, p, USV::SVD, V̄::Adjoint, A::AbstractMatrix) =
    svd_rev(USV, zeroslike(USV.U), zeroslike(USV.S), V̄)
∇(::typeof(svd), ::Type{Arg{1}}, p, USV::SVD, Ū::AbstractMatrix, A::AbstractMatrix) =
    svd_rev(USV, Ū, zeroslike(USV.S), zeroslike(USV.V))

# Register `getproperty(::SVD, ::Symbol)` with Nabla's interception machinery.
# NOTE(review): the `[true, false]` mask presumably marks only the first argument (the
# `SVD` object) as differentiable — confirm against the `@explicit_intercepts` docs.
@explicit_intercepts getproperty Tuple{SVD, Symbol} [true, false]

# Reverse-mode sensitivity of `getproperty(USV, x)` with respect to `USV`.
#
# `ȳ` is the sensitivity of the extracted factor. It is returned reshaped/wrapped so that
# the `svd` sensitivity can tell from the type alone which factor it belongs to:
#   * `:S` -> a vector,
#   * `:U` -> a matrix with the size of `USV.U`,
#   * `:V` -> an `Adjoint`-wrapped matrix.
# Throws an `ArgumentError` for `:Vt` (unsupported) and for any other property name.
function ∇(::typeof(getproperty), ::Type{Arg{1}}, p, y, ȳ, USV::SVD, x::Symbol)
    if x === :S
        return vec(ȳ)
    elseif x === :U
        return reshape(ȳ, size(USV.U))
    elseif x === :V
        # This is so we can ensure that the result is an Adjoint, otherwise dispatch
        # won't work properly
        return copy(ȳ')'
    elseif x === :Vt
        throw(ArgumentError("Vt is unsupported; use V and transpose the result"))
    else
        throw(ArgumentError("unrecognized property $x; expected U, S, or V"))
    end
end

"""
    svd_rev(USV, Ū, s̄, V̄)

Compute the reverse mode sensitivities of the singular value decomposition (SVD). `USV` is
an `SVD` factorization object produced by a call to `svd`, and `Ū`, `s̄`, and `V̄` are the
respective sensitivities of the `U`, `S`, and `V` factors. Returns `Ā`, the sensitivity of
the factorized matrix.

With `S = Diagonal(s)`, `S̄ = Diagonal(s̄)` and `F[i,j] = 1 / (s[j]^2 - s[i]^2)` for
`i != j`, the result is the sum of three terms:

    Ā = (U (F ∘ (UᵀŪ - ŪᵀU)) S + (I - UUᵀ) Ū S⁻¹) Vᵀ
      + U S̄ Vᵀ
      + U (S (F ∘ (VᵀV̄ - V̄ᵀV)) Vᵀ + S⁻¹ V̄ᵀ (I - VVᵀ))

!!! note
    A thin factorization (`svd(A, full=false)`, the default) is assumed. Repeated singular
    values make the denominator of `F` zero, so the result is only meaningful when the
    singular values are distinct.
"""
function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix)
    # Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default
    U = USV.U
    s = USV.S
    V = USV.V
    Vt = USV.Vt

    k = length(s)
    T = eltype(s)
    # The diagonal of F is never used (mulsubtrans! zeroes the diagonal itself), so it is
    # filled with 1 merely to avoid dividing by zero.
    F = T[i == j ? 1 : inv(@inbounds s[j]^2 - s[i]^2) for i = 1:k, j = 1:k]

    # We do a lot of matrix operations here, so we'll try to be memory-friendly and do
    # as many of the computations in-place as possible. Benchmarking shows that the in-
    # place functions here are significantly faster than their out-of-place, naively
    # implemented counterparts, and allocate no additional memory.
    Ut = U'
    FUᵀŪ = mulsubtrans!(Ut*Ū, F)  # F .* (UᵀŪ - ŪᵀU)
    FVᵀV̄ = mulsubtrans!(Vt*V̄, F)  # F .* (VᵀV̄ - V̄ᵀV)
    ImUUᵀ = eyesubx!(U*Ut)        # I - UUᵀ
    ImVVᵀ = eyesubx!(V*Vt)        # I - VVᵀ

    S = Diagonal(s)
    S̄ = Diagonal(s̄)

    # Accumulate the contributions of Ū, S̄ and V̄ (in that order) into Ā.
    Ā = add!(U*FUᵀŪ*S, ImUUᵀ*(Ū/S))*Vt
    add!(Ā, U*S̄*Vt)
    add!(Ā, U*add!(S*FVᵀV̄*Vt, (S\V̄')*ImVVᵀ))

    return Ā
end

"""
    mulsubtrans!(X::AbstractMatrix, F::AbstractMatrix)

Overwrite `X` with `F .* (X - X')` and return it. The diagonal of the result is set to
zero without consulting the diagonal of `F`.

!!! note
    Internal helper: indexing is done under `@inbounds` and no shape validation is
    performed, so callers must pass square matrices of matching dimensions.
"""
function mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
    n = size(X, 1)
    @inbounds for col in 1:n
        # Handle each strictly-upper entry together with its mirror image, reading both
        # old values before either one is overwritten.
        for row in 1:(col - 1)
            upper = X[row, col]
            lower = X[col, row]
            X[row, col] = F[row, col] * (upper - lower)
            X[col, row] = F[col, row] * (lower - upper)
        end
        X[col, col] = zero(T)
    end
    return X
end

"""
    eyesubx!(X::AbstractMatrix)

Overwrite `X` with `I - X` and return it. `X` does not have to be square: entries whose
row and column index coincide become `1 - X[i,i]`, all others are negated.
"""
function eyesubx!(X::AbstractMatrix{T}) where T<:Real
    # Column-major traversal over the 1-based index grid given by `size(X)`.
    @inbounds for pos in CartesianIndices(size(X))
        ondiag = pos[1] == pos[2]
        X[pos] = ondiag - X[pos]
    end
    return X
end

"""
    add!(X::AbstractMatrix, Y::AbstractMatrix)

Overwrite `X` with `X + Y` and return it.

!!! note
    Internal helper. The matrices need not be square, but they must have matching
    dimensions: the shared index set comes from `eachindex(X, Y)`, which raises a
    `DimensionMismatch` when the two arrays disagree in shape.
"""
function add!(X::AbstractMatrix{T}, Y::AbstractMatrix{T}) where T<:Real
    shared = eachindex(X, Y)
    @inbounds for idx in shared
        X[idx] = X[idx] + Y[idx]
    end
    return X
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ end

@testset "Factorisations" begin
include("sensitivities/linalg/factorization/cholesky.jl")
include("sensitivities/linalg/factorization/svd.jl")
end
end
end
41 changes: 41 additions & 0 deletions test/sensitivities/linalg/factorization/svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
@testset "SVD" begin
    # Check the sensitivity of each supported factor against finite differences for tall,
    # square, and wide inputs.
    @testset "Comparison with finite differencing" begin
        rng = MersenneTwister(12345)
        for n in [4, 6, 10], m in [3, 5, 10]
            k = min(n, m)
            A = randn(rng, n, m)
            VA = randn(rng, n, m)
            @test check_errs(X->svd(X).U, randn(rng, n, k), A, VA)
            @test check_errs(X->svd(X).S, randn(rng, k), A, VA)
            @test check_errs(X->svd(X).V, randn(rng, m, k), A, VA)
        end
    end

    # `Vt` is explicitly unsupported, and unknown properties must error as well.
    @testset "Error conditions" begin
        rng = MersenneTwister(12345)
        A = randn(rng, 5, 3)
        V̄t = randn(rng, 3, 3)
        @test_throws ArgumentError check_errs(X->svd(X).Vt, V̄t, A, A)
        @test_throws ErrorException check_errs(X->svd(X).whoops, V̄t, A, A)
    end

    # The traced `svd` must yield a `Branch` wrapping an `SVD` whose factors match the
    # expected values for a rectangular identity matrix.
    @testset "Branch consistency" begin
        X_ = Matrix{Float64}(I, 3, 5)
        X = Leaf(Tape(), X_)
        USV = svd(X)
        @test USV isa Branch{<:SVD}
        @test getfield(USV, :f) == svd
        @test unbox(USV.U) ≈ Matrix{Float64}(I, 3, 3)
        @test unbox(USV.S) ≈ ones(Float64, 3)
        @test unbox(USV.V) ≈ Matrix{Float64}(I, 5, 3)
    end

    # The in-place helpers must agree with their straightforward out-of-place forms.
    @testset "Helper functions" begin
        rng = MersenneTwister(12345)
        X = randn(rng, 10, 10)
        Y = randn(rng, 10, 10)
        @test Nabla.mulsubtrans!(copy(X), Y) ≈ Y .* (X - X')
        @test Nabla.eyesubx!(copy(X)) ≈ I - X
        @test Nabla.add!(copy(X), Y) ≈ X + Y
    end
end

0 comments on commit 2026960

Please sign in to comment.