Skip to content

Commit

Permalink
Change post to pre hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Nov 14, 2024
1 parent 0c57a2f commit ffdb816
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 116 deletions.
18 changes: 9 additions & 9 deletions ext/ClimaTimeSteppersBenchmarkToolsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ n_calls_per_step(::CTS.ARS343, max_newton_iters) = Dict(
"T_exp_T_lim!" => 4,
"lim!" => 4,
"dss!" => 4,
"post_explicit!" => 3,
"post_implicit!" => 4,
"pre_explicit!" => 3,
"pre_implicit!" => 4,
"step!" => 1,
)
function n_calls_per_step(alg::CTS.RosenbrockAlgorithm)
Expand All @@ -47,8 +47,8 @@ function n_calls_per_step(alg::CTS.RosenbrockAlgorithm)
"T_exp_T_lim!" => CTS.n_stages(alg.tableau),
"lim!" => 0,
"dss!" => CTS.n_stages(alg.tableau),
"post_explicit!" => 0,
"post_implicit!" => CTS.n_stages(alg.tableau),
"pre_explicit!" => 0,
"pre_implicit!" => CTS.n_stages(alg.tableau),
"step!" => 1,
)
end
Expand All @@ -60,7 +60,7 @@ function maybe_push!(trials₀, name, f!, args, kwargs, only)
end

const allowed_names =
["Wfact", "ldiv!", "T_imp!", "T_exp_T_lim!", "lim!", "dss!", "post_explicit!", "post_implicit!", "step!"]
["Wfact", "ldiv!", "T_imp!", "T_exp_T_lim!", "lim!", "dss!", "pre_explicit!", "pre_implicit!", "step!"]

"""
benchmark_step(
Expand Down Expand Up @@ -89,8 +89,8 @@ Benchmark a DistributedODEIntegrator given:
- "T_exp_T_lim!"
- "lim!"
- "dss!"
- "post_explicit!"
- "post_implicit!"
- "pre_explicit!"
- "pre_implicit!"
- "step!"
"""
function CTS.benchmark_step(
Expand Down Expand Up @@ -123,8 +123,8 @@ function CTS.benchmark_step(
maybe_push!(trials₀, "T_exp_T_lim!", remaining_fun(integrator), remaining_args(integrator), kwargs, only)
maybe_push!(trials₀, "lim!", f.lim!, (Xlim, p, t, u), kwargs, only)
maybe_push!(trials₀, "dss!", f.dss!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "post_explicit!", f.post_explicit!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "post_implicit!", f.post_implicit!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "pre_explicit!", f.pre_explicit!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "pre_implicit!", f.pre_implicit!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "step!", SciMLBase.step!, (integrator, ), kwargs, only)
#! format: on

Expand Down
10 changes: 5 additions & 5 deletions src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFuncti
T_imp!::TI
lim!::L
dss!::D
post_explicit!::PE
post_implicit!::PI
pre_explicit!::PE
pre_implicit!::PI
function ClimaODEFunction(;
T_exp_T_lim! = nothing, # nothing or (uₜ_exp, uₜ_lim, u, p, t) -> ...
T_lim! = nothing, # nothing or (uₜ, u, p, t) -> ...
T_exp! = nothing, # nothing or (uₜ, u, p, t) -> ...
T_imp! = nothing, # nothing or (uₜ, u, p, t) -> ...
lim! = (u, p, t, u_ref) -> nothing,
dss! = (u, p, t) -> nothing,
post_explicit! = (u, p, t) -> nothing,
post_implicit! = (u, p, t) -> nothing,
pre_explicit! = (u, p, t) -> nothing,
pre_implicit! = (u, p, t) -> nothing,
)
args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, post_explicit!, post_implicit!)
args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, pre_explicit!, pre_implicit!)

if !isnothing(T_exp_T_lim!)
@assert isnothing(T_exp!) "`T_exp_T_lim!` was passed, `T_exp!` must be `nothing`"
Expand Down
4 changes: 2 additions & 2 deletions src/integrators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ function DiffEqBase.__init(
tdir,
)
if prob.f isa ClimaODEFunction
(; post_explicit!) = prob.f
isnothing(post_explicit!) || post_explicit!(u0, p, t0)
(; pre_explicit!) = prob.f
isnothing(pre_explicit!) || pre_explicit!(u0, p, t0)
end
DiffEqBase.initialize!(callback, u0, t0, integrator)
return integrator
Expand Down
51 changes: 18 additions & 33 deletions src/nl_solvers/newtons_method.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ struct ForwardDiffStepSize3 <: ForwardDiffStepSize end
Computes the Jacobian-vector product `j(x[n]) * Δx[n]` for a Newton-Krylov
method without directly using the Jacobian `j(x[n])`, and instead only using
`x[n]`, `f(x[n])`, and other function evaluations `f(x′)`. This is done by
calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)`.
calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, pre_implicit!)`.
The `jΔx` passed to a Jacobian-free JVP is modified in-place. The `cache` can
be obtained with `allocate_cache(::JacobianFreeJVP, x_prototype)`, where
`x_prototype` is `similar` to `x` (and also to `Δx` and `f`).
Expand All @@ -151,13 +151,13 @@ end

allocate_cache(::ForwardDiffJVP, x_prototype) = (; x2 = similar(x_prototype), f2 = similar(x_prototype))

function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)
function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, pre_implicit!)
(; default_step, step_adjustment) = alg
(; x2, f2) = cache
FT = eltype(x)
ε = FT(step_adjustment) * default_step(Δx, x)
@. x2 = x + ε * Δx
isnothing(post_implicit!) || post_implicit!(x2)
isnothing(pre_implicit!) || pre_implicit!(x2)
f!(f2, x2)
@. jΔx = (f2 - f) / ε
end
Expand Down Expand Up @@ -343,7 +343,7 @@ end
Finds an approximation `Δx[n] ≈ j(x[n]) \\ f(x[n])` for Newton's method such
that `‖f(x[n]) - j(x[n]) * Δx[n]‖ ≤ rtol[n] * ‖f(x[n])‖`, where `rtol[n]` is the
value of the forcing term on iteration `n`. This is done by calling
`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)`,
`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, pre_implicit!, j = nothing)`,
where `f` is `f(x[n])` and, if it is specified, `j` is either `j(x[n])` or an
approximation of `j(x[n])`. The `Δx` passed to a Krylov method is modified in-place.
The `cache` can be obtained with `allocate_cache(::KrylovMethod, x_prototype)`,
Expand Down Expand Up @@ -428,14 +428,14 @@ function allocate_cache(alg::KrylovMethod, x_prototype)
)
end

NVTX.@annotate function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)
NVTX.@annotate function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, pre_implicit!, j = nothing)
(; jacobian_free_jvp, forcing_term, solve_kwargs) = alg
(; disable_preconditioner, debugger) = alg
type = solver_type(alg)
(; jacobian_free_jvp_cache, forcing_term_cache, solver, debugger_cache) = cache
jΔx!(jΔx, Δx) =
isnothing(jacobian_free_jvp) ? mul!(jΔx, j, Δx) :
jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, post_implicit!)
jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, pre_implicit!)
opj = LinearOperator(eltype(x), length(x), length(x), false, false, jΔx!)
M = disable_preconditioner || isnothing(j) || isnothing(jacobian_free_jvp) ? I : j
print_debug!(debugger, debugger_cache, opj, M)
Expand Down Expand Up @@ -567,32 +567,22 @@ function allocate_cache(alg::NewtonsMethod, x_prototype, j_prototype = nothing)
)
end

solve_newton!(
alg::NewtonsMethod,
cache::Nothing,
x,
f!,
j! = nothing,
post_implicit! = nothing,
post_implicit_last! = nothing,
) = nothing

NVTX.@annotate function solve_newton!(
alg::NewtonsMethod,
cache,
x,
f!,
j! = nothing,
post_implicit! = nothing,
post_implicit_last! = nothing,
)
solve_newton!(alg::NewtonsMethod, cache::Nothing, x, f!, j! = nothing, pre_implicit! = nothing) = nothing

NVTX.@annotate function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing, pre_implicit! = nothing)
(; max_iters, update_j, krylov_method, convergence_checker, verbose) = alg
(; krylov_method_cache, convergence_checker_cache) = cache
(; Δx, f, j) = cache
if (!isnothing(j)) && needs_update!(update_j, NewNewtonSolve())
j!(j, x)
if !isnothing(pre_implicit!) && !isempty(1:max_iters)
pre_implicit!(x)
if (!isnothing(j)) && needs_update!(update_j, NewNewtonSolve())
j!(j, x)
end
end
for n in 1:max_iters
if !isnothing(pre_implicit!)
n 1 && pre_implicit!(x)
end
# Compute Δx[n].
if (!isnothing(j)) && needs_update!(update_j, NewNewtonIteration())
j!(j, x)
Expand All @@ -605,20 +595,15 @@ NVTX.@annotate function solve_newton!(
ldiv!(Δx, j, f)
end
else
solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, post_implicit!, j)
solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, pre_implicit!, j)
end
is_verbose(verbose) && @info "Newton iteration $n: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))"

x .-= Δx
# Update x[n] with Δx[n - 1], and exit the loop if Δx[n] is not needed.
# Check for convergence if necessary.
if is_converged!(convergence_checker, convergence_checker_cache, x, Δx, n)
isnothing(post_implicit_last!) || post_implicit_last!(x)
break
elseif n == max_iters
isnothing(post_implicit_last!) || post_implicit_last!(x)
else
isnothing(post_implicit!) || post_implicit!(x)
end
if is_verbose(verbose) && n == max_iters
@warn "Newton's method did not converge within $n iterations: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))"
Expand Down
28 changes: 14 additions & 14 deletions src/solvers/hard_coded_ars343.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
(; u, p, t, dt, sol, alg) = integrator
(; f) = sol.prob
(; T_imp!, lim!, dss!) = f
(; post_explicit!, post_implicit!) = f
(; pre_explicit!, pre_implicit!) = f
(; tableau, newtons_method) = alg
(; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau
(; U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache
Expand Down Expand Up @@ -34,7 +34,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
lim!(U, p, t_exp, u)
@. U += dt * a_exp[i, 1] * T_exp[1]
dss!(U, p, t_exp)
post_explicit!(U, p, t_exp)

@. temp = U # used in closures
let i = i
Expand All @@ -46,21 +45,22 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
implicit_equation_jacobian! = (jacobian, Ui) -> begin
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
end
call_post_implicit! = Ui -> begin
post_implicit!(Ui, p, t_imp)
call_pre_implicit! = Ui -> begin
pre_implicit!(Ui, p, t_imp)
end
solve_newton!(
newtons_method,
newtons_method_cache,
U,
implicit_equation_residual!,
implicit_equation_jacobian!,
call_post_implicit!,
call_pre_implicit!,
)
end

@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])

pre_explicit!(U, p, t_exp)
T_lim!(T_lim[i], U, p, t_exp)
T_exp!(T_exp[i], U, p, t_exp)

Expand All @@ -70,7 +70,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
lim!(U, p, t_exp, u)
@. U += dt * a_exp[i, 1] * T_exp[1] + dt * a_exp[i, 2] * T_exp[2] + dt * a_imp[i, 2] * T_imp[2]
dss!(U, p, t_exp)
post_explicit!(U, p, t_exp)

@. temp = U # used in closures
let i = i
Expand All @@ -82,21 +81,22 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
implicit_equation_jacobian! = (jacobian, Ui) -> begin
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
end
call_post_implicit! = Ui -> begin
post_implicit!(Ui, p, t_imp)
call_pre_implicit! = Ui -> begin
pre_implicit!(Ui, p, t_imp)
end
solve_newton!(
newtons_method,
newtons_method_cache,
U,
implicit_equation_residual!,
implicit_equation_jacobian!,
call_post_implicit!,
call_pre_implicit!,
)
end

@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])

pre_explicit!(U, p, t_exp)
T_lim!(T_lim[i], U, p, t_exp)
T_exp!(T_exp[i], U, p, t_exp)
i = 4
Expand All @@ -110,7 +110,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
dt * a_imp[i, 2] * T_imp[2] +
dt * a_imp[i, 3] * T_imp[3]
dss!(U, p, t_exp)
post_explicit!(U, p, t_exp)

@. temp = U # used in closures
let i = i
Expand All @@ -122,21 +121,22 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
implicit_equation_jacobian! = (jacobian, Ui) -> begin
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
end
call_post_implicit! = Ui -> begin
post_implicit!(Ui, p, t_imp)
call_pre_implicit! = Ui -> begin
pre_implicit!(Ui, p, t_imp)
end
solve_newton!(
newtons_method,
newtons_method_cache,
U,
implicit_equation_residual!,
implicit_equation_jacobian!,
call_post_implicit!,
call_pre_implicit!,
)
end

@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])

pre_explicit!(U, p, t_exp)
T_lim!(T_lim[i], U, p, t_exp)
T_exp!(T_exp[i], U, p, t_exp)

Expand All @@ -155,6 +155,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
dt * b_imp[3] * T_imp[3] +
dt * b_imp[4] * T_imp[4]
dss!(u, p, t_final)
post_explicit!(u, p, t_final)
pre_explicit!(U, p, t_final)
return u
end
Loading

0 comments on commit ffdb816

Please sign in to comment.