diff --git a/src/ClimaTimeSteppers.jl b/src/ClimaTimeSteppers.jl index 78a14a90..add59048 100644 --- a/src/ClimaTimeSteppers.jl +++ b/src/ClimaTimeSteppers.jl @@ -118,6 +118,7 @@ const SPCO = SparseCoeffs include("solvers/imex_tableaus.jl") include("solvers/explicit_tableaus.jl") +include("solvers/compute_T_exp_T_lim.jl") include("solvers/imex_ark.jl") include("solvers/imex_ssprk.jl") include("solvers/multirate.jl") diff --git a/src/functions.jl b/src/functions.jl index 4e80446c..e6fb2502 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -4,7 +4,7 @@ export ClimaODEFunction, ForwardEulerODEFunction abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end -Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction +Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI, CC} <: AbstractClimaODEFunction T_lim!::TL = nothing # nothing or (uₜ, u, p, t) -> ... T_exp!::TE = nothing # nothing or (uₜ, u, p, t) -> ... T_imp!::TI = nothing # nothing or (uₜ, u, p, t) -> ... @@ -12,6 +12,7 @@ Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI} <: AbstractClimaOD dss!::D = (u, p, t) -> nothing post_explicit!::PE = (u, p, t) -> nothing post_implicit!::PI = (u, p, t) -> nothing + comms_context::CC = nothing end # Don't wrap a AbstractClimaODEFunction in an ODEFunction (makes ODEProblem work). diff --git a/src/solvers/compute_T_exp_T_lim.jl b/src/solvers/compute_T_exp_T_lim.jl new file mode 100644 index 00000000..d2758b14 --- /dev/null +++ b/src/solvers/compute_T_exp_T_lim.jl @@ -0,0 +1,13 @@ +@inline function compute_T_lim_T_exp!( + T_lim, + T_exp, + U, + p, + t, + T_lim!, + T_exp!, + ::Union{Nothing, ClimaComms.AbstractCommsContext}, +) + T_lim!(T_lim, U, p, t) + T_exp!(T_exp, U, p, t) +end diff --git a/src/solvers/imex_ark.jl b/src/solvers/imex_ark.jl index 2bdf601d..29b4d5ae 100644 --- a/src/solvers/imex_ark.jl +++ b/src/solvers/imex_ark.jl @@ -169,8 +169,13 @@ end end if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i]) - isnothing(T_lim!) || T_lim!(T_lim[i], U, p, t_exp) - isnothing(T_exp!) || T_exp!(T_exp[i], U, p, t_exp) + if !isnothing(T_lim!) && !isnothing(T_exp!) + (; comms_context) = f + compute_T_lim_T_exp!(T_lim[i], T_exp[i], U, p, t_exp, T_lim!, T_exp!, comms_context) + else + isnothing(T_lim!) || T_lim!(T_lim[i], U, p, t_exp) + isnothing(T_exp!) || T_exp!(T_exp[i], U, p, t_exp) + end end return nothing diff --git a/src/solvers/imex_ssprk.jl b/src/solvers/imex_ssprk.jl index 05090db6..4cc5f069 100644 --- a/src/solvers/imex_ssprk.jl +++ b/src/solvers/imex_ssprk.jl @@ -153,11 +153,12 @@ function step_u!(integrator, cache::IMEXSSPRKCache) end if !iszero(β[i]) - if !isnothing(T_lim!) - T_lim!(T_lim, U, p, t_exp) - end - if !isnothing(T_exp!) - T_exp!(T_exp, U, p, t_exp) + if !isnothing(T_lim!) && !isnothing(T_exp!) + (; comms_context) = f + compute_T_lim_T_exp!(T_lim[i], T_exp[i], U, p, t_exp, T_lim!, T_exp!, comms_context) + else + isnothing(T_lim!) || T_lim!(T_lim, U, p, t_exp) + isnothing(T_exp!) || T_exp!(T_exp, U, p, t_exp) end end end