From 0d465639d847ae73b971c667b63443bf6dcbc29b Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 27 Sep 2024 00:31:22 +0200 Subject: [PATCH] symv (#269) * Bump patch version * symv rule implementation * Fix comments --- Project.toml | 2 +- src/rrules/blas.jl | 106 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 104 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 44f1074a..919cb0c0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.2" +version = "0.4.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index 98604bd5..0fd47ac2 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -19,7 +19,7 @@ function tri!(A, u::Char, d::Char) return u == 'L' ? tril!(A, d == 'U' ? -1 : 0) : triu!(A, d == 'U' ? 1 : 0) end - +const MatrixOrView{T} = Union{Matrix{T}, SubArray{T, 2, Matrix{T}}} # # LEVEL 1 @@ -198,6 +198,91 @@ for (gemv, elty) in ((:dgemv_, :Float64), (:sgemv_, :Float32)) end end +@is_primitive( + MinimalCtx, + Tuple{ + typeof(BLAS.symv!), + Char, + T, + MatrixOrView{T}, + Vector{T}, + T, + Vector{T}, + } where {T<:Union{Float32, Float64}}, +) + +function rrule!!( + ::CoDual{typeof(BLAS.symv!)}, + uplo::CoDual{Char}, + alpha::CoDual{T}, + A_dA::CoDual{<:MatrixOrView{T}}, + x_dx::CoDual{Vector{T}}, + beta::CoDual{T}, + y_dy::CoDual{Vector{T}}, +) where {T<:Union{Float32, Float64}} + + # Extract primals. + ul = primal(uplo) + α = primal(alpha) + β = primal(beta) + A, dA = viewify(A_dA) + x, dx = viewify(x_dx) + y, dy = viewify(y_dy) + + # In this rule we optimise carefully for the special case a == 1 && b == 0, which + # corresponds to simply multiplying symm(A) and x together, and writing the result to y. + # This is an extremely common edge case, so it's important to do well for it. + y_copy = copy(y) + tmp_ref = Ref{Vector{T}}() + if (α == 1 && β == 0) + BLAS.symv!(ul, α, A, x, β, y) + else + tmp = BLAS.symv(ul, one(T), A, x) + tmp_ref[] = tmp + BLAS.axpby!(α, tmp, β, y) + end + + function symv!_adjoint(::NoRData) + + if (α == 1 && β == 0) + dα = dot(dy, y) + BLAS.copyto!(y, y_copy) + else + # Reset y. + BLAS.copyto!(y, y_copy) + + # gradient w.r.t. α. Safe to write into memory for copy of y. + BLAS.symv!(ul, one(T), A, x, zero(T), y_copy) + dα = dot(dy, y_copy) + end + + # gradient w.r.t. A. + dA_tmp = dy * x' + if ul == 'L' + dA .+= α .* LowerTriangular(dA_tmp) + dA .+= α .* UpperTriangular(dA_tmp)' + else + dA .+= α .* LowerTriangular(dA_tmp)' + dA .+= α .* UpperTriangular(dA_tmp) + end + @inbounds for n in diagind(dA) + dA[n] -= α * dA_tmp[n] + end + + # gradient w.r.t. x. + BLAS.symv!(ul, α, A, dy, one(T), dx) + + # gradient w.r.t. beta. + dβ = dot(dy, y) + + # gradient w.r.t. y. + BLAS.scal!(β, dy) + + return NoRData(), NoRData(), dα, NoRData(), NoRData(), dβ, NoRData() + end + return y_dy, symv!_adjoint +end + for (trmv, elty) in ((:dtrmv_, :Float64), (:strmv_, :Float32)) @eval @inline function rrule!!( ::CoDual{typeof(_foreigncall_)}, @@ -260,8 +345,6 @@ end # LEVEL 3 # -const MatrixOrView{T} = Union{Matrix{T}, SubArray{T, 2, Matrix{T}}} - @is_primitive( MinimalCtx, Tuple{ @@ -336,6 +419,7 @@ function rrule!!( return C, gemm!_pb!! end +viewify(A::CoDual{<:Vector}) = primal(A), tangent(A) viewify(A::CoDual{<:Matrix}) = view(primal(A), :, :), view(tangent(A), :, :) function viewify(A::CoDual{P}) where {P<:SubArray} p_A = primal(A) @@ -742,6 +826,22 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) betas = [0.0, 0.33] test_cases = vcat( + + # symv! + vec(reduce( + vcat, + vec(map(product(['L', 'U'], alphas, betas)) do (uplo, α, β) + A = randn(5, 5) + vA = view(randn(15, 15), 1:5, 1:5) + x = randn(5) + y = randn(5) + return Any[ + (false, :stability, nothing, BLAS.symv!, uplo, α, A, x, β, y), + (false, :stability, nothing, BLAS.symv!, uplo, α, vA, x, β, y), + ] + end) + )), + # gemm! vec(reduce( vcat,