Skip to content

Commit

Permalink
Merge #1364
Browse files Browse the repository at this point in the history
1364: Refactor SchurComplementW for ClimaTimesteppers r=charleskawczynski a=charleskawczynski

This PR is a peel off from #1358.

This PR adds `temp1` and `temp2` to `SchurComplementW`, and defines
```julia
linsolve!(::Type{Val{:init}}, f, u0; kwargs...) = _linsolve!
_linsolve!(x, A, b, update_matrix = false; kwargs...) =
    LinearAlgebra.ldiv!(x, A, b)

# Function required by Krylov.jl (x and b can be AbstractVectors)
# See JuliaSmoothOptimizers/Krylov.jl#605 for a
# related issue that requires the same workaround.
function LinearAlgebra.ldiv!(x, A::SchurComplementW, b)
    A.temp1 .= b
    LinearAlgebra.ldiv!(A.temp2, A, A.temp1)
    x .= A.temp2
end
```
and the original `_linsolve` contents are in
```julia
function LinearAlgebra.ldiv!(
    x::Fields.FieldVector,
    A::SchurComplementW,
    b::Fields.FieldVector,
)
```

(the pattern used in ClimaAtmos). It also renames `_linsolve!` to `test_linsolve!` to avoid a name collision in the test suite.

It's a much smaller PR than it appears, due to indenting.

Co-authored-by: Charles Kawczynski <[email protected]>
  • Loading branch information
bors[bot] and charleskawczynski authored Jul 14, 2023
2 parents 445448e + 78203db commit 22534ef
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 86 deletions.
187 changes: 105 additions & 82 deletions examples/hybrid/schur_complement_W.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using ClimaCore.Utilities: half
const compose = Operators.ComposeStencils()
const apply = Operators.ApplyStencil()

struct SchurComplementW{F, FT, J1, J2, J3, J4, S}
struct SchurComplementW{F, FT, J1, J2, J3, J4, S, T}
# whether this struct is used to compute Wfact_t or Wfact
transform::Bool

Expand All @@ -28,6 +28,10 @@ struct SchurComplementW{F, FT, J1, J2, J3, J4, S}

# whether to test the Jacobian and linear solver
test::Bool

# cache that is used to evaluate ldiv!
temp1::T
temp2::T
end

function SchurComplementW(Y, transform, flags, test = false)
Expand Down Expand Up @@ -61,6 +65,7 @@ function SchurComplementW(Y, transform, flags, test = false)
typeof(∂ᶠ𝕄ₜ∂ᶜρ),
typeof(∂ᶠ𝕄ₜ∂ᶠ𝕄),
typeof(S),
typeof(Y),
}(
transform,
flags,
Expand All @@ -72,6 +77,8 @@ function SchurComplementW(Y, transform, flags, test = false)
∂ᶠ𝕄ₜ∂ᶠ𝕄,
S,
test,
similar(Y),
similar(Y),
)
end

Expand Down Expand Up @@ -101,91 +108,107 @@ Finally, use (1) and (2) to get x1 and x2.
Note: The matrix S = A31 A13 + A32 A23 + A33 - I is the "Schur complement" of
[-I 0; 0 -I] (the top-left 4 blocks) in A.
=#
function linsolve!(::Type{Val{:init}}, f, u0; kwargs...)
function _linsolve!(x, A, b, update_matrix = false; kwargs...)
(; dtγ_ref, ∂ᶜρₜ∂ᶠ𝕄, ∂ᶜ𝔼ₜ∂ᶠ𝕄, ∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶠ𝕄ₜ∂ᶠ𝕄) = A
(; S) = A
dtγ = dtγ_ref[]

xᶜρ = x.c.ρ
bᶜρ = b.c.ρ
if :ρθ in propertynames(x.c)
xᶜ𝔼 = x.c.ρθ
bᶜ𝔼 = b.c.ρθ
elseif :ρe in propertynames(x.c)
xᶜ𝔼 = x.c.ρe
bᶜ𝔼 = b.c.ρe
elseif :ρe_int in propertynames(x.c)
xᶜ𝔼 = x.c.ρe_int
bᶜ𝔼 = b.c.ρe_int
end
if :ρw in propertynames(x.f)
xᶠ𝕄 = x.f.ρw.components.data.:1
bᶠ𝕄 = b.f.ρw.components.data.:1
elseif :w in propertynames(x.f)
xᶠ𝕄 = x.f.w.components.data.:1
bᶠ𝕄 = b.f.w.components.data.:1
end
# Function required by OrdinaryDiffEq.jl
linsolve!(::Type{Val{:init}}, f, u0; kwargs...) = _linsolve!
_linsolve!(x, A, b, update_matrix = false; kwargs...) =
LinearAlgebra.ldiv!(x, A, b)

# Function required by Krylov.jl (x and b can be AbstractVectors)
# See https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/605 for a
# related issue that requires the same workaround.
function LinearAlgebra.ldiv!(x, A::SchurComplementW, b)
A.temp1 .= b
LinearAlgebra.ldiv!(A.temp2, A, A.temp1)
x .= A.temp2
end

# TODO: Extend LinearAlgebra.I to work with stencil fields.
FT = eltype(eltype(S))
I = Ref(Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT))))
if Operators.bandwidths(eltype(∂ᶜ𝔼ₜ∂ᶠ𝕄)) != (-half, half)
str = "The linear solver cannot yet be run with the given ∂ᶜ𝔼ₜ/∂ᶠ𝕄 \
block, since it has more than 2 diagonals. So, ∂ᶜ𝔼ₜ/∂ᶠ𝕄 will \
be set to 0 for the Schur complement computation. Consider \
changing the ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode or the energy variable."
@warn str maxlog = 1
@. S = dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) + dtγ * ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I
else
@. S =
dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) +
dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶜ𝔼ₜ∂ᶠ𝕄) +
dtγ * ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I
end
function LinearAlgebra.ldiv!(
x::Fields.FieldVector,
A::SchurComplementW,
b::Fields.FieldVector,
)
(; dtγ_ref, ∂ᶜρₜ∂ᶠ𝕄, ∂ᶜ𝔼ₜ∂ᶠ𝕄, ∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶠ𝕄ₜ∂ᶠ𝕄) = A
(; S) = A
dtγ = dtγ_ref[]

xᶜρ = x.c.ρ
bᶜρ = b.c.ρ
if :ρθ in propertynames(x.c)
xᶜ𝔼 = x.c.ρθ
bᶜ𝔼 = b.c.ρθ
elseif :ρe in propertynames(x.c)
xᶜ𝔼 = x.c.ρe
bᶜ𝔼 = b.c.ρe
elseif :ρe_int in propertynames(x.c)
xᶜ𝔼 = x.c.ρe_int
bᶜ𝔼 = b.c.ρe_int
end
if :ρw in propertynames(x.f)
xᶠ𝕄 = x.f.ρw.components.data.:1
bᶠ𝕄 = b.f.ρw.components.data.:1
elseif :w in propertynames(x.f)
xᶠ𝕄 = x.f.w.components.data.:1
bᶠ𝕄 = b.f.w.components.data.:1
end

@. xᶠ𝕄 = bᶠ𝕄 + dtγ * (apply(∂ᶠ𝕄ₜ∂ᶜρ, bᶜρ) + apply(∂ᶠ𝕄ₜ∂ᶜ𝔼, bᶜ𝔼))

Operators.column_thomas_solve!(S, xᶠ𝕄)

@. xᶜρ = -bᶜρ + dtγ * apply(∂ᶜρₜ∂ᶠ𝕄, xᶠ𝕄)
@. xᶜ𝔼 = -bᶜ𝔼 + dtγ * apply(∂ᶜ𝔼ₜ∂ᶠ𝕄, xᶠ𝕄)

if A.test && Operators.bandwidths(eltype(∂ᶜ𝔼ₜ∂ᶠ𝕄)) == (-half, half)
Ni, Nj, _, Nv, Nh = size(Spaces.local_geometry_data(axes(xᶜρ)))
∂Yₜ∂Y = Array{FT}(undef, 3 * Nv + 1, 3 * Nv + 1)
ΔY = Array{FT}(undef, 3 * Nv + 1)
ΔΔY = Array{FT}(undef, 3 * Nv + 1)
for h in 1:Nh, j in 1:Nj, i in 1:Ni
∂Yₜ∂Y .= zero(FT)
∂Yₜ∂Y[1:Nv, (2 * Nv + 1):(3 * Nv + 1)] .=
matrix_column(∂ᶜρₜ∂ᶠ𝕄, axes(x.f), i, j, h)
∂Yₜ∂Y[(Nv + 1):(2 * Nv), (2 * Nv + 1):(3 * Nv + 1)] .=
matrix_column(∂ᶜ𝔼ₜ∂ᶠ𝕄, axes(x.f), i, j, h)
∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), 1:Nv] .=
matrix_column(∂ᶠ𝕄ₜ∂ᶜρ, axes(x.c), i, j, h)
∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), (Nv + 1):(2 * Nv)] .=
matrix_column(∂ᶠ𝕄ₜ∂ᶜ𝔼, axes(x.c), i, j, h)
∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), (2 * Nv + 1):(3 * Nv + 1)] .=
matrix_column(∂ᶠ𝕄ₜ∂ᶠ𝕄, axes(x.f), i, j, h)
ΔY[1:Nv] .= vector_column(xᶜρ, i, j, h)
ΔY[(Nv + 1):(2 * Nv)] .= vector_column(xᶜ𝔼, i, j, h)
ΔY[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(xᶠ𝕄, i, j, h)
ΔΔY[1:Nv] .= vector_column(bᶜρ, i, j, h)
ΔΔY[(Nv + 1):(2 * Nv)] .= vector_column(bᶜ𝔼, i, j, h)
ΔΔY[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(bᶠ𝕄, i, j, h)
@assert (-LinearAlgebra.I + dtγ * ∂Yₜ∂Y) * ΔY ΔΔY
end
end
# TODO: Extend LinearAlgebra.I to work with stencil fields.
FT = eltype(eltype(S))
I = Ref(Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT))))
if Operators.bandwidths(eltype(∂ᶜ𝔼ₜ∂ᶠ𝕄)) != (-half, half)
str = "The linear solver cannot yet be run with the given ∂ᶜ𝔼ₜ/∂ᶠ𝕄 \
block, since it has more than 2 diagonals. So, ∂ᶜ𝔼ₜ/∂ᶠ𝕄 will \
be set to 0 for the Schur complement computation. Consider \
changing the ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode or the energy variable."
@warn str maxlog = 1
@. S = dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) + dtγ * ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I
else
@. S =
dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) +
dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶜ𝔼ₜ∂ᶠ𝕄) +
dtγ * ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I
end

if :ρuₕ in propertynames(x.c)
@. x.c.ρuₕ = -b.c.ρuₕ
elseif :uₕ in propertynames(x.c)
@. x.c.uₕ = -b.c.uₕ
@. xᶠ𝕄 = bᶠ𝕄 + dtγ * (apply(∂ᶠ𝕄ₜ∂ᶜρ, bᶜρ) + apply(∂ᶠ𝕄ₜ∂ᶜ𝔼, bᶜ𝔼))

Operators.column_thomas_solve!(S, xᶠ𝕄)

@. xᶜρ = -bᶜρ + dtγ * apply(∂ᶜρₜ∂ᶠ𝕄, xᶠ𝕄)
@. xᶜ𝔼 = -bᶜ𝔼 + dtγ * apply(∂ᶜ𝔼ₜ∂ᶠ𝕄, xᶠ𝕄)

if A.test && Operators.bandwidths(eltype(∂ᶜ𝔼ₜ∂ᶠ𝕄)) == (-half, half)
Ni, Nj, _, Nv, Nh = size(Spaces.local_geometry_data(axes(xᶜρ)))
∂Yₜ∂Y = Array{FT}(undef, 3 * Nv + 1, 3 * Nv + 1)
ΔY = Array{FT}(undef, 3 * Nv + 1)
ΔΔY = Array{FT}(undef, 3 * Nv + 1)
for h in 1:Nh, j in 1:Nj, i in 1:Ni
∂Yₜ∂Y .= zero(FT)
∂Yₜ∂Y[1:Nv, (2 * Nv + 1):(3 * Nv + 1)] .=
matrix_column(∂ᶜρₜ∂ᶠ𝕄, axes(x.f), i, j, h)
∂Yₜ∂Y[(Nv + 1):(2 * Nv), (2 * Nv + 1):(3 * Nv + 1)] .=
matrix_column(∂ᶜ𝔼ₜ∂ᶠ𝕄, axes(x.f), i, j, h)
∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), 1:Nv] .=
matrix_column(∂ᶠ𝕄ₜ∂ᶜρ, axes(x.c), i, j, h)
∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), (Nv + 1):(2 * Nv)] .=
matrix_column(∂ᶠ𝕄ₜ∂ᶜ𝔼, axes(x.c), i, j, h)
∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), (2 * Nv + 1):(3 * Nv + 1)] .=
matrix_column(∂ᶠ𝕄ₜ∂ᶠ𝕄, axes(x.f), i, j, h)
ΔY[1:Nv] .= vector_column(xᶜρ, i, j, h)
ΔY[(Nv + 1):(2 * Nv)] .= vector_column(xᶜ𝔼, i, j, h)
ΔY[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(xᶠ𝕄, i, j, h)
ΔΔY[1:Nv] .= vector_column(bᶜρ, i, j, h)
ΔΔY[(Nv + 1):(2 * Nv)] .= vector_column(bᶜ𝔼, i, j, h)
ΔΔY[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(bᶠ𝕄, i, j, h)
@assert (-LinearAlgebra.I + dtγ * ∂Yₜ∂Y) * ΔY ΔΔY
end
end

if A.transform
x .*= dtγ
end
if :ρuₕ in propertynames(x.c)
@. x.c.ρuₕ = -b.c.ρuₕ
elseif :uₕ in propertynames(x.c)
@. x.c.uₕ = -b.c.uₕ
end

if A.transform
x .*= dtγ
end
end
8 changes: 4 additions & 4 deletions test/Operators/finitedifference/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space)
=#
face_space = Spaces.FaceFiniteDifferenceSpace(center_space)

function _linsolve!(x, A, b, update_matrix = false; kwargs...)
function test_linsolve!(x, A, b, update_matrix = false; kwargs...)

FT = Spaces.undertype(axes(x.c))

Expand Down Expand Up @@ -88,11 +88,11 @@ W = SchurComplementW(Y, use_transform, jacobi_flags)

using JET
using Test
@time _linsolve!(Y, W, b)
@time _linsolve!(Y, W, b)
@time test_linsolve!(Y, W, b)
@time test_linsolve!(Y, W, b)

@testset "JET test for `apply` in linsolve! kernel" begin
@test_opt _linsolve!(Y, W, b)
@test_opt test_linsolve!(Y, W, b)
end

ClimaCore.Operators.allow_mismatched_fd_spaces() = false

0 comments on commit 22534ef

Please sign in to comment.