From 78203db57d15f4f014d62b0ab9cccfe346c9f064 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Tue, 11 Jul 2023 11:38:46 -0700 Subject: [PATCH] Refactor SchurComplementW for ClimaTimesteppers --- examples/hybrid/schur_complement_W.jl | 187 +++++++++++--------- test/Operators/finitedifference/linsolve.jl | 8 +- 2 files changed, 109 insertions(+), 86 deletions(-) diff --git a/examples/hybrid/schur_complement_W.jl b/examples/hybrid/schur_complement_W.jl index 5e211dd1ba..3d40c8485e 100644 --- a/examples/hybrid/schur_complement_W.jl +++ b/examples/hybrid/schur_complement_W.jl @@ -6,7 +6,7 @@ using ClimaCore.Utilities: half const compose = Operators.ComposeStencils() const apply = Operators.ApplyStencil() -struct SchurComplementW{F, FT, J1, J2, J3, J4, S} +struct SchurComplementW{F, FT, J1, J2, J3, J4, S, T} # whether this struct is used to compute Wfact_t or Wfact transform::Bool @@ -28,6 +28,10 @@ struct SchurComplementW{F, FT, J1, J2, J3, J4, S} # whether to test the Jacobian and linear solver test::Bool + + # cache that is used to evaluate ldiv! + temp1::T + temp2::T end function SchurComplementW(Y, transform, flags, test = false) @@ -61,6 +65,7 @@ function SchurComplementW(Y, transform, flags, test = false) typeof(โˆ‚แถ ๐•„โ‚œโˆ‚แถœฯ), typeof(โˆ‚แถ ๐•„โ‚œโˆ‚แถ ๐•„), typeof(S), + typeof(Y), }( transform, flags, @@ -72,6 +77,8 @@ function SchurComplementW(Y, transform, flags, test = false) โˆ‚แถ ๐•„โ‚œโˆ‚แถ ๐•„, S, test, + similar(Y), + similar(Y), ) end @@ -101,91 +108,107 @@ Finally, use (1) and (2) to get x1 and x2. Note: The matrix S = A31 A13 + A32 A23 + A33 - I is the "Schur complement" of [-I 0; 0 -I] (the top-left 4 blocks) in A. =# -function linsolve!(::Type{Val{:init}}, f, u0; kwargs...) - function _linsolve!(x, A, b, update_matrix = false; kwargs...) - (; dtฮณ_ref, โˆ‚แถœฯโ‚œโˆ‚แถ ๐•„, โˆ‚แถœ๐”ผโ‚œโˆ‚แถ ๐•„, โˆ‚แถ ๐•„โ‚œโˆ‚แถœ๐”ผ, โˆ‚แถ ๐•„โ‚œโˆ‚แถœฯ, โˆ‚แถ ๐•„โ‚œโˆ‚แถ ๐•„) = A - (; S) = A - dtฮณ = dtฮณ_ref[] - - xแถœฯ = x.c.ฯ - bแถœฯ = b.c.ฯ - if :ฯฮธ in propertynames(x.c) - xแถœ๐”ผ = x.c.ฯฮธ - bแถœ๐”ผ = b.c.ฯฮธ - elseif :ฯe in propertynames(x.c) - xแถœ๐”ผ = x.c.ฯe - bแถœ๐”ผ = b.c.ฯe - elseif :ฯe_int in propertynames(x.c) - xแถœ๐”ผ = x.c.ฯe_int - bแถœ๐”ผ = b.c.ฯe_int - end - if :ฯw in propertynames(x.f) - xแถ ๐•„ = x.f.ฯw.components.data.:1 - bแถ ๐•„ = b.f.ฯw.components.data.:1 - elseif :w in propertynames(x.f) - xแถ ๐•„ = x.f.w.components.data.:1 - bแถ ๐•„ = b.f.w.components.data.:1 - end +# Function required by OrdinaryDiffEq.jl +linsolve!(::Type{Val{:init}}, f, u0; kwargs...) = _linsolve! +_linsolve!(x, A, b, update_matrix = false; kwargs...) = + LinearAlgebra.ldiv!(x, A, b) + +# Function required by Krylov.jl (x and b can be AbstractVectors) +# See https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/605 for a +# related issue that requires the same workaround. +function LinearAlgebra.ldiv!(x, A::SchurComplementW, b) + A.temp1 .= b + LinearAlgebra.ldiv!(A.temp2, A, A.temp1) + x .= A.temp2 +end - # TODO: Extend LinearAlgebra.I to work with stencil fields. - FT = eltype(eltype(S)) - I = Ref(Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT)))) - if Operators.bandwidths(eltype(โˆ‚แถœ๐”ผโ‚œโˆ‚แถ ๐•„)) != (-half, half) - str = "The linear solver cannot yet be run with the given โˆ‚แถœ๐”ผโ‚œ/โˆ‚แถ ๐•„ \ - block, since it has more than 2 diagonals. So, โˆ‚แถœ๐”ผโ‚œ/โˆ‚แถ ๐•„ will \ - be set to 0 for the Schur complement computation. Consider \ - changing the โˆ‚แถœ๐”ผโ‚œโˆ‚แถ ๐•„_mode or the energy variable." - @warn str maxlog = 1 - @. S = dtฮณ^2 * compose(โˆ‚แถ ๐•„โ‚œโˆ‚แถœฯ, โˆ‚แถœฯโ‚œโˆ‚แถ ๐•„) + dtฮณ * โˆ‚แถ ๐•„โ‚œโˆ‚แถ ๐•„ - I - else - @. S = - dtฮณ^2 * compose(โˆ‚แถ ๐•„โ‚œโˆ‚แถœฯ, โˆ‚แถœฯโ‚œโˆ‚แถ ๐•„) + - dtฮณ^2 * compose(โˆ‚แถ ๐•„โ‚œโˆ‚แถœ๐”ผ, โˆ‚แถœ๐”ผโ‚œโˆ‚แถ ๐•„) + - dtฮณ * โˆ‚แถ ๐•„โ‚œโˆ‚แถ ๐•„ - I - end +function LinearAlgebra.ldiv!( + x::Fields.FieldVector, + A::SchurComplementW, + b::Fields.FieldVector, +) + (; dtฮณ_ref, โˆ‚แถœฯโ‚œโˆ‚แถ ๐•„, โˆ‚แถœ๐”ผโ‚œโˆ‚แถ ๐•„, โˆ‚แถ ๐•„โ‚œโˆ‚แถœ๐”ผ, โˆ‚แถ ๐•„โ‚œโˆ‚แถœฯ, โˆ‚แถ ๐•„โ‚œโˆ‚แถ ๐•„) = A + (; S) = A + dtฮณ = dtฮณ_ref[] + + xแถœฯ = x.c.ฯ + bแถœฯ = b.c.ฯ + if :ฯฮธ in propertynames(x.c) + xแถœ๐”ผ = x.c.ฯฮธ + bแถœ๐”ผ = b.c.ฯฮธ + elseif :ฯe in propertynames(x.c) + xแถœ๐”ผ = x.c.ฯe + bแถœ๐”ผ = b.c.ฯe + elseif :ฯe_int in propertynames(x.c) + xแถœ๐”ผ = x.c.ฯe_int + bแถœ๐”ผ = b.c.ฯe_int + end + if :ฯw in propertynames(x.f) + xแถ ๐•„ = x.f.ฯw.components.data.:1 + bแถ ๐•„ = b.f.ฯw.components.data.:1 + elseif :w in propertynames(x.f) + xแถ ๐•„ = x.f.w.components.data.:1 + bแถ ๐•„ = b.f.w.components.data.:1 + end - @. xแถ ๐•„ = bแถ ๐•„ + dtฮณ * (apply(โˆ‚แถ ๐•„โ‚œโˆ‚แถœฯ, bแถœฯ) + apply(โˆ‚แถ ๐•„โ‚œโˆ‚แถœ๐”ผ, bแถœ๐”ผ)) - - Operators.column_thomas_solve!(S, xแถ ๐•„) - - @. xแถœฯ = -bแถœฯ + dtฮณ * apply(โˆ‚แถœฯโ‚œโˆ‚แถ ๐•„, xแถ ๐•„) - @. xแถœ๐”ผ = -bแถœ๐”ผ + dtฮณ * apply(โˆ‚แถœ๐”ผโ‚œโˆ‚แถ ๐•„, xแถ ๐•„) - - if A.test && Operators.bandwidths(eltype(โˆ‚แถœ๐”ผโ‚œโˆ‚แถ ๐•„)) == (-half, half) - Ni, Nj, _, Nv, Nh = size(Spaces.local_geometry_data(axes(xแถœฯ))) - โˆ‚Yโ‚œโˆ‚Y = Array{FT}(undef, 3 * Nv + 1, 3 * Nv + 1) - ฮ”Y = Array{FT}(undef, 3 * Nv + 1) - ฮ”ฮ”Y = Array{FT}(undef, 3 * Nv + 1) - for h in 1:Nh, j in 1:Nj, i in 1:Ni - โˆ‚Yโ‚œโˆ‚Y .= zero(FT) - โˆ‚Yโ‚œโˆ‚Y[1:Nv, (2 * Nv + 1):(3 * Nv + 1)] .= - matrix_column(โˆ‚แถœฯโ‚œโˆ‚แถ ๐•„, axes(x.f), i, j, h) - โˆ‚Yโ‚œโˆ‚Y[(Nv + 1):(2 * Nv), (2 * Nv + 1):(3 * Nv + 1)] .= - matrix_column(โˆ‚แถœ๐”ผโ‚œโˆ‚แถ ๐•„, axes(x.f), i, j, h) - โˆ‚Yโ‚œโˆ‚Y[(2 * Nv + 1):(3 * Nv + 1), 1:Nv] .= - matrix_column(โˆ‚แถ ๐•„โ‚œโˆ‚แถœฯ, axes(x.c), i, j, h) - โˆ‚Yโ‚œโˆ‚Y[(2 * Nv + 1):(3 * Nv + 1), (Nv + 1):(2 * Nv)] .= - matrix_column(โˆ‚แถ ๐•„โ‚œโˆ‚แถœ๐”ผ, axes(x.c), i, j, h) - โˆ‚Yโ‚œโˆ‚Y[(2 * Nv + 1):(3 * Nv + 1), (2 * Nv + 1):(3 * Nv + 1)] .= - matrix_column(โˆ‚แถ ๐•„โ‚œโˆ‚แถ ๐•„, axes(x.f), i, j, h) - ฮ”Y[1:Nv] .= vector_column(xแถœฯ, i, j, h) - ฮ”Y[(Nv + 1):(2 * Nv)] .= vector_column(xแถœ๐”ผ, i, j, h) - ฮ”Y[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(xแถ ๐•„, i, j, h) - ฮ”ฮ”Y[1:Nv] .= vector_column(bแถœฯ, i, j, h) - ฮ”ฮ”Y[(Nv + 1):(2 * Nv)] .= vector_column(bแถœ๐”ผ, i, j, h) - ฮ”ฮ”Y[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(bแถ ๐•„, i, j, h) - @assert (-LinearAlgebra.I + dtฮณ * โˆ‚Yโ‚œโˆ‚Y) * ฮ”Y โ‰ˆ ฮ”ฮ”Y - end - end + # TODO: Extend LinearAlgebra.I to work with stencil fields. + FT = eltype(eltype(S)) + I = Ref(Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT)))) + if Operators.bandwidths(eltype(โˆ‚แถœ๐”ผโ‚œโˆ‚แถ ๐•„)) != (-half, half) + str = "The linear solver cannot yet be run with the given โˆ‚แถœ๐”ผโ‚œ/โˆ‚แถ ๐•„ \ + block, since it has more than 2 diagonals. So, โˆ‚แถœ๐”ผโ‚œ/โˆ‚แถ ๐•„ will \ + be set to 0 for the Schur complement computation. Consider \ + changing the โˆ‚แถœ๐”ผโ‚œโˆ‚แถ ๐•„_mode or the energy variable." + @warn str maxlog = 1 + @. S = dtฮณ^2 * compose(โˆ‚แถ ๐•„โ‚œโˆ‚แถœฯ, โˆ‚แถœฯโ‚œโˆ‚แถ ๐•„) + dtฮณ * โˆ‚แถ ๐•„โ‚œโˆ‚แถ ๐•„ - I + else + @. S = + dtฮณ^2 * compose(โˆ‚แถ ๐•„โ‚œโˆ‚แถœฯ, โˆ‚แถœฯโ‚œโˆ‚แถ ๐•„) + + dtฮณ^2 * compose(โˆ‚แถ ๐•„โ‚œโˆ‚แถœ๐”ผ, โˆ‚แถœ๐”ผโ‚œโˆ‚แถ ๐•„) + + dtฮณ * โˆ‚แถ ๐•„โ‚œโˆ‚แถ ๐•„ - I + end - if :ฯuโ‚• in propertynames(x.c) - @. x.c.ฯuโ‚• = -b.c.ฯuโ‚• - elseif :uโ‚• in propertynames(x.c) - @. x.c.uโ‚• = -b.c.uโ‚• + @. xแถ ๐•„ = bแถ ๐•„ + dtฮณ * (apply(โˆ‚แถ ๐•„โ‚œโˆ‚แถœฯ, bแถœฯ) + apply(โˆ‚แถ ๐•„โ‚œโˆ‚แถœ๐”ผ, bแถœ๐”ผ)) + + Operators.column_thomas_solve!(S, xแถ ๐•„) + + @. xแถœฯ = -bแถœฯ + dtฮณ * apply(โˆ‚แถœฯโ‚œโˆ‚แถ ๐•„, xแถ ๐•„) + @. xแถœ๐”ผ = -bแถœ๐”ผ + dtฮณ * apply(โˆ‚แถœ๐”ผโ‚œโˆ‚แถ ๐•„, xแถ ๐•„) + + if A.test && Operators.bandwidths(eltype(โˆ‚แถœ๐”ผโ‚œโˆ‚แถ ๐•„)) == (-half, half) + Ni, Nj, _, Nv, Nh = size(Spaces.local_geometry_data(axes(xแถœฯ))) + โˆ‚Yโ‚œโˆ‚Y = Array{FT}(undef, 3 * Nv + 1, 3 * Nv + 1) + ฮ”Y = Array{FT}(undef, 3 * Nv + 1) + ฮ”ฮ”Y = Array{FT}(undef, 3 * Nv + 1) + for h in 1:Nh, j in 1:Nj, i in 1:Ni + โˆ‚Yโ‚œโˆ‚Y .= zero(FT) + โˆ‚Yโ‚œโˆ‚Y[1:Nv, (2 * Nv + 1):(3 * Nv + 1)] .= + matrix_column(โˆ‚แถœฯโ‚œโˆ‚แถ ๐•„, axes(x.f), i, j, h) + โˆ‚Yโ‚œโˆ‚Y[(Nv + 1):(2 * Nv), (2 * Nv + 1):(3 * Nv + 1)] .= + matrix_column(โˆ‚แถœ๐”ผโ‚œโˆ‚แถ ๐•„, axes(x.f), i, j, h) + โˆ‚Yโ‚œโˆ‚Y[(2 * Nv + 1):(3 * Nv + 1), 1:Nv] .= + matrix_column(โˆ‚แถ ๐•„โ‚œโˆ‚แถœฯ, axes(x.c), i, j, h) + โˆ‚Yโ‚œโˆ‚Y[(2 * Nv + 1):(3 * Nv + 1), (Nv + 1):(2 * Nv)] .= + matrix_column(โˆ‚แถ ๐•„โ‚œโˆ‚แถœ๐”ผ, axes(x.c), i, j, h) + โˆ‚Yโ‚œโˆ‚Y[(2 * Nv + 1):(3 * Nv + 1), (2 * Nv + 1):(3 * Nv + 1)] .= + matrix_column(โˆ‚แถ ๐•„โ‚œโˆ‚แถ ๐•„, axes(x.f), i, j, h) + ฮ”Y[1:Nv] .= vector_column(xแถœฯ, i, j, h) + ฮ”Y[(Nv + 1):(2 * Nv)] .= vector_column(xแถœ๐”ผ, i, j, h) + ฮ”Y[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(xแถ ๐•„, i, j, h) + ฮ”ฮ”Y[1:Nv] .= vector_column(bแถœฯ, i, j, h) + ฮ”ฮ”Y[(Nv + 1):(2 * Nv)] .= vector_column(bแถœ๐”ผ, i, j, h) + ฮ”ฮ”Y[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(bแถ ๐•„, i, j, h) + @assert (-LinearAlgebra.I + dtฮณ * โˆ‚Yโ‚œโˆ‚Y) * ฮ”Y โ‰ˆ ฮ”ฮ”Y end + end - if A.transform - x .*= dtฮณ - end + if :ฯuโ‚• in propertynames(x.c) + @. x.c.ฯuโ‚• = -b.c.ฯuโ‚• + elseif :uโ‚• in propertynames(x.c) + @. x.c.uโ‚• = -b.c.uโ‚• + end + + if A.transform + x .*= dtฮณ end end diff --git a/test/Operators/finitedifference/linsolve.jl b/test/Operators/finitedifference/linsolve.jl index cf6041391b..5ab3e6549d 100644 --- a/test/Operators/finitedifference/linsolve.jl +++ b/test/Operators/finitedifference/linsolve.jl @@ -35,7 +35,7 @@ face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space) =# face_space = Spaces.FaceFiniteDifferenceSpace(center_space) -function _linsolve!(x, A, b, update_matrix = false; kwargs...) +function test_linsolve!(x, A, b, update_matrix = false; kwargs...) FT = Spaces.undertype(axes(x.c)) @@ -88,11 +88,11 @@ W = SchurComplementW(Y, use_transform, jacobi_flags) using JET using Test -@time _linsolve!(Y, W, b) -@time _linsolve!(Y, W, b) +@time test_linsolve!(Y, W, b) +@time test_linsolve!(Y, W, b) @testset "JET test for `apply` in linsolve! kernel" begin - @test_opt _linsolve!(Y, W, b) + @test_opt test_linsolve!(Y, W, b) end ClimaCore.Operators.allow_mismatched_fd_spaces() = false