From 8476cd3c8d21d4e710f72a887cd91637376071fa Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 19 May 2023 20:50:09 +0800 Subject: [PATCH 1/4] Unrelated change: inv is broken on GPU in 1.9 also --- test/rulesets/Base/arraymath.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 5eaf9e7fc..e0546cc01 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -1,7 +1,7 @@ @testset "arraymath.jl" begin @testset "inv(::Matrix{$T})" for T in (Float64, ComplexF64) B = generate_well_conditioned_matrix(T, 3) - if VERSION >= v"1.7" + if v"1.7" <= VERSION < v"1.9" @gpu test_frule(inv, B) @gpu test_rrule(inv, B) else From c75d92054649d08cb7ed64143bc3302b4daf2813 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 19 May 2023 23:05:03 +0800 Subject: [PATCH 2/4] Fix / on 1.9 --- src/rulesets/Base/arraymath.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 3673c0a43..e1e2626c7 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -342,20 +342,24 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R project_B = ProjectTo(B) Y = A \ B + # Ever since https://github.com/JuliaLang/julia/pull/44358 + # we need to use `pinv` rather than `/` to support both the cases of Y being scalar and array + # See also https://github.com/JuliaLang/julia/issues/28827 which would improve this function backslash_pullback(ȳ) Ȳ = unthunk(ȳ) + Ati = pinv(A') ∂A = @thunk begin - B̄ = A' \ Ȳ + + B̄ = Ati * Ȳ Ā = -B̄ * Y' - Ā = add!!(Ā, (B - A * Y) * B̄' / A') - Ā = add!!(Ā, A' \ Y * (Ȳ' - B̄'A)) + Ā = add!!(Ā, ((B - A * Y) * B̄') * Ati) + Ā = add!!(Ā, Ati * Y * (Ȳ' - B̄'A)) project_A(Ā) end - ∂B = @thunk project_B(A' \ Ȳ) + ∂B = @thunk project_B(Ati * Ȳ) return NoTangent(), ∂A, ∂B end return Y, backslash_pullback - end 
##### From ccd4196d69f80e7fa6c7ca240d44cb09b45cc16a Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 25 May 2023 14:58:16 +0800 Subject: [PATCH 3/4] just arrayify scalar, and also prefactorize A' --- src/rulesets/Base/arraymath.jl | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index e1e2626c7..594a4bfe9 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -342,21 +342,24 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R project_B = ProjectTo(B) Y = A \ B - # Ever since https://github.com/JuliaLang/julia/pull/44358 - # we need to use `pinv` rather than `/` to support both the cases of Y being scalar and array - # See also https://github.com/JuliaLang/julia/issues/28827 which would improve this + + Atf = factorize(A') + function backslash_pullback(ȳ) Ȳ = unthunk(ȳ) - Ati = pinv(A') + @static if VERSION >= v"1.9" + # Need to ensure Ȳ is an array since https://github.com/JuliaLang/julia/pull/44358 + Ȳ isa AbstractArray || Ȳ = [Ȳ] + end + Atf = factorize(A') ∂A = @thunk begin - - B̄ = Ati * Ȳ + B̄ = Atf \ Ȳ Ā = -B̄ * Y' - Ā = add!!(Ā, ((B - A * Y) * B̄') * Ati) - Ā = add!!(Ā, Ati * Y * (Ȳ' - B̄'A)) + Ā = add!!(Ā, ((B - A * Y) * B̄') / Atf) + Ā = add!!(Ā, Atf \ Y * (Ȳ' - B̄'A)) project_A(Ā) end - ∂B = @thunk project_B(Ati * Ȳ) + ∂B = @thunk project_B(Atf \ Ȳ) return NoTangent(), ∂A, ∂B end return Y, backslash_pullback From fb8cacd819a26b9c7ad656ec45710ffbb08624c6 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 2 Jun 2023 17:21:18 +0800 Subject: [PATCH 4/4] only do minimal change to rule for \ to convert to array Also make second Y not scalar more coercing some things into arrays some of the time cleaner def with a helper function --- src/rulesets/Base/arraymath.jl | 59 ++++++++++++++++++++++++++++----- test/rulesets/Base/arraymath.jl | 4 +-- 2 files changed, 53 insertions(+), 10 
deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 594a4bfe9..85fd8df51 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -343,28 +343,71 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R Y = A \ B - Atf = factorize(A') - function backslash_pullback(ȳ) Ȳ = unthunk(ȳ) + + Ȳf = Ȳ @static if VERSION >= v"1.9" # Need to ensure Ȳ is an array since https://github.com/JuliaLang/julia/pull/44358 - Ȳ isa AbstractArray || Ȳ = [Ȳ] + if !isa(Ȳ, AbstractArray) + Ȳf = [Ȳ] + end + end + Yf = Y + @static if VERSION >= v"1.9" + # Need to ensure Yf is an array since https://github.com/JuliaLang/julia/pull/44358 + if !isa(Y, AbstractArray) + Yf = [Y] + end end - Atf = factorize(A') + #@info "vars" typeof(Ȳ) typeof(Y) typeof(Yf) typeof(A) typeof(B) ∂A = @thunk begin - B̄ = Atf \ Ȳ + B̄ = A' \ Ȳf Ā = -B̄ * Y' - Ā = add!!(Ā, ((B - A * Y) * B̄') / Atf) - Ā = add!!(Ā, Atf \ Y * (Ȳ' - B̄'A)) + t = (B - A * Y) * B̄' + @static if VERSION >= v"1.9" + # Need to ensure t is an array since https://github.com/JuliaLang/julia/pull/44358 + if !isa(t, AbstractArray) + t = [t] + end + end + Ā = add!!(Ā, t / A') + Ā = add!!(Ā, A' \ Yf * (Ȳ' - B̄'A)) project_A(Ā) end - ∂B = @thunk project_B(Atf \ Ȳ) + ∂B = @thunk project_B(A' \ Ȳf) return NoTangent(), ∂A, ∂B end return Y, backslash_pullback end +@static if VERSION >= v"1.9" + # Need to ensure things are not scalar since https://github.com/JuliaLang/julia/pull/44358 + _maybe_descalar(x) = x isa AbstractArray ? 
x : [x] +else + _maybe_descalar(x) = x +end + +function rrule(A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) + Y = A \ B + + + function backslash_pullback(ȳ) + Ȳ = unthunk(ȳ) + + ∂A = @thunk begin + B̄ = A' \ _maybe_descalar(Ȳ) + Ā = -B̄ * Y' + Ā += _maybe_descalar((B - A * Y) * B̄') / A' + Ā += (A' \ _maybe_descalar(Y)) * (Ȳ' - B̄'A) + (Ā) + end + ∂B = @thunk (A' \ _maybe_descalar(Ȳ)) + return ∂A, ∂B + end + return Y, backslash_pullback +end + ##### ##### `\`, `/` matrix-scalar_rule ##### diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index e0546cc01..847808c1f 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -167,12 +167,12 @@ @testset "Matrix $f Vector" begin X = randn(10, 4) y = randn(10) - test_rrule(f, X, y) + test_rrule(f, X, y; check_inferred=false) end @testset "Vector $f Matrix" begin x = randn(10) Y = randn(10, 4) - test_rrule(f, x, Y; output_tangent=Transpose(rand(4))) + test_rrule(f, x, Y; output_tangent=Transpose(rand(4)), check_inferred=false) end else A = rand(2, 4)