This repository has been archived by the owner on Apr 18, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement sensitivities for SVD (#131)
This supports `svd` and retrieving the `U`, `S`, or `V` property of the resulting `SVD` object.
- Loading branch information
Showing
4 changed files
with
160 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import LinearAlgebra: svd | ||
import Base: getproperty | ||
|
||
@explicit_intercepts svd Tuple{AbstractMatrix{<:Real}} | ||
|
||
∇(::typeof(svd), ::Type{Arg{1}}, p, USV::SVD, S̄::AbstractVector, A::AbstractMatrix) = | ||
svd_rev(USV, zeroslike(USV.U), S̄, zeroslike(USV.V)) | ||
∇(::typeof(svd), ::Type{Arg{1}}, p, USV::SVD, V̄::Adjoint, A::AbstractMatrix) = | ||
svd_rev(USV, zeroslike(USV.U), zeroslike(USV.S), V̄) | ||
∇(::typeof(svd), ::Type{Arg{1}}, p, USV::SVD, Ū::AbstractMatrix, A::AbstractMatrix) = | ||
svd_rev(USV, Ū, zeroslike(USV.S), zeroslike(USV.V)) | ||
|
||
@explicit_intercepts getproperty Tuple{SVD, Symbol} [true, false] | ||
|
||
function ∇(::typeof(getproperty), ::Type{Arg{1}}, p, y, ȳ, USV::SVD, x::Symbol) | ||
if x === :S | ||
return vec(ȳ) | ||
elseif x === :U | ||
return reshape(ȳ, size(USV.U)) | ||
elseif x === :V | ||
# This is so we can ensure that the result is an Adjoint, otherwise dispatch | ||
# won't work properly | ||
return copy(ȳ')' | ||
elseif x === :Vt | ||
throw(ArgumentError("Vt is unsupported; use V and transpose the result")) | ||
else | ||
throw(ArgumentError("unrecognized property $x; expected U, S, or V")) | ||
end | ||
end | ||
|
||
""" | ||
svd_rev(USV, Ū, S̄, V̄) | ||
Compute the reverse mode sensitivities of the singular value decomposition (SVD). `USV` is | ||
an `SVD` factorization object produced by a call to `svd`, and `Ū`, `S̄`, and `V̄` are the | ||
respective sensitivities of the `U`, `S`, and `V` factors. | ||
""" | ||
function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix) | ||
# Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default | ||
U = USV.U | ||
s = USV.S | ||
V = USV.V | ||
Vt = USV.Vt | ||
|
||
k = length(s) | ||
T = eltype(s) | ||
F = T[i == j ? 1 : inv(@inbounds s[j]^2 - s[i]^2) for i = 1:k, j = 1:k] | ||
|
||
# We do a lot of matrix operations here, so we'll try to be memory-friendly and do | ||
# as many of the computations in-place as possible. Benchmarking shows that the in- | ||
# place functions here are significantly faster than their out-of-place, naively | ||
# implemented counterparts, and allocate no additional memory. | ||
Ut = U' | ||
FUᵀŪ = mulsubtrans!(Ut*Ū, F) # F .* (UᵀŪ - ŪᵀU) | ||
FVᵀV̄ = mulsubtrans!(Vt*V̄, F) # F .* (VᵀV̄ - V̄ᵀV) | ||
ImUUᵀ = eyesubx!(U*Ut) # I - UUᵀ | ||
ImVVᵀ = eyesubx!(V*Vt) # I - VVᵀ | ||
|
||
S = Diagonal(s) | ||
S̄ = Diagonal(s̄) | ||
|
||
Ā = add!(U*FUᵀŪ*S, ImUUᵀ*(Ū/S))*Vt | ||
add!(Ā, U*S̄*Vt) | ||
add!(Ā, U*add!(S*FVᵀV̄*Vt, (S\V̄')*ImVVᵀ)) | ||
|
||
return Ā | ||
end | ||
|
||
""" | ||
mulsubtrans!(X::AbstractMatrix, F::AbstractMatrix) | ||
Compute `F .* (X - X')`, overwriting `X` in the process. | ||
!!! note | ||
This is an internal function that does no argument checking; the matrices passed to | ||
this function are square with matching dimensions by construction. | ||
""" | ||
function mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real | ||
k = size(X, 1) | ||
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle | ||
if i == j | ||
X[i,i] = zero(T) | ||
else | ||
X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j]) | ||
end | ||
end | ||
X | ||
end | ||
|
||
""" | ||
eyesubx!(X::AbstractMatrix) | ||
Compute `I - X`, overwriting `X` in the process. | ||
""" | ||
function eyesubx!(X::AbstractMatrix{T}) where T<:Real | ||
n, m = size(X) | ||
@inbounds for j = 1:m, i = 1:n | ||
X[i,j] = (i == j) - X[i,j] | ||
end | ||
X | ||
end | ||
|
||
""" | ||
add!(X::AbstractMatrix, Y::AbstractMatrix) | ||
Compute `X + Y`, overwriting X in the process. | ||
!!! note | ||
This is an internal function that does no argument checking; the matrices passed to | ||
this function are square with matching dimensions by construction. | ||
""" | ||
function add!(X::AbstractMatrix{T}, Y::AbstractMatrix{T}) where T<:Real | ||
@inbounds for i = eachindex(X, Y) | ||
X[i] += Y[i] | ||
end | ||
X | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
@testset "SVD" begin | ||
@testset "Comparison with finite differencing" begin | ||
rng = MersenneTwister(12345) | ||
for n in [4, 6, 10], m in [3, 5, 10] | ||
k = min(n, m) | ||
A = randn(rng, n, m) | ||
VA = randn(rng, n, m) | ||
@test check_errs(X->svd(X).U, randn(rng, n, k), A, VA) | ||
@test check_errs(X->svd(X).S, randn(rng, k), A, VA) | ||
@test check_errs(X->svd(X).V, randn(rng, m, k), A, VA) | ||
end | ||
end | ||
|
||
@testset "Error conditions" begin | ||
rng = MersenneTwister(12345) | ||
A = randn(rng, 5, 3) | ||
V̄t = randn(rng, 3, 3) | ||
@test_throws ArgumentError check_errs(X->svd(X).Vt, V̄t, A, A) | ||
@test_throws ErrorException check_errs(X->svd(X).whoops, V̄t, A, A) | ||
end | ||
|
||
@testset "Branch consistency" begin | ||
X_ = Matrix{Float64}(I, 3, 5) | ||
X = Leaf(Tape(), X_) | ||
USV = svd(X) | ||
@test USV isa Branch{<:SVD} | ||
@test getfield(USV, :f) == svd | ||
@test unbox(USV.U) ≈ Matrix{Float64}(I, 3, 3) | ||
@test unbox(USV.S) ≈ ones(Float64, 3) | ||
@test unbox(USV.V) ≈ Matrix{Float64}(I, 5, 3) | ||
end | ||
|
||
@testset "Helper functions" begin | ||
rng = MersenneTwister(12345) | ||
X = randn(rng, 10, 10) | ||
Y = randn(rng, 10, 10) | ||
@test Nabla.mulsubtrans!(copy(X), Y) ≈ Y .* (X - X') | ||
@test Nabla.eyesubx!(copy(X)) ≈ I - X | ||
@test Nabla.add!(copy(X), Y) ≈ X + Y | ||
end | ||
end |