diff --git a/src/functions.jl b/src/functions.jl index 4e80446c2..e6fb2502e 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -4,7 +4,7 @@ export ClimaODEFunction, ForwardEulerODEFunction abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end -Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction +Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI, CC} <: AbstractClimaODEFunction T_lim!::TL = nothing # nothing or (uₜ, u, p, t) -> ... T_exp!::TE = nothing # nothing or (uₜ, u, p, t) -> ... T_imp!::TI = nothing # nothing or (uₜ, u, p, t) -> ... @@ -12,6 +12,7 @@ Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI} <: AbstractClimaOD dss!::D = (u, p, t) -> nothing post_explicit!::PE = (u, p, t) -> nothing post_implicit!::PI = (u, p, t) -> nothing + comms_context::CC = nothing end # Don't wrap a AbstractClimaODEFunction in an ODEFunction (makes ODEProblem work). diff --git a/src/solvers/imex_ark.jl b/src/solvers/imex_ark.jl index 2b3e27360..ffabd0eb0 100644 --- a/src/solvers/imex_ark.jl +++ b/src/solvers/imex_ark.jl @@ -50,6 +50,7 @@ function step_u!(integrator, cache::IMEXARKCache) (; u, p, t, dt, alg) = integrator (; f) = integrator.sol.prob (; post_explicit!, post_implicit!) = f + (; comms_context) = f (; T_lim!, T_exp!, T_imp!, lim!, dss!) = f (; tableau, newtons_method) = alg (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau @@ -147,11 +148,21 @@ function step_u!(integrator, cache::IMEXARKCache) end if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i]) - if !isnothing(T_lim!) - T_lim!(T_lim[i], U, p, t_exp) - end - if !isnothing(T_exp!) - T_exp!(T_exp[i], U, p, t_exp) + if !isnothing(T_lim!) && !isnothing(T_lim!) && !isnothing(comms_context) + dev = ClimaComms.device(comms_context) + ClimaComms.@sync dev begin + @async begin + T_lim!(T_lim[i], U, p, t_exp) + nothing + end + @async begin + T_exp!(T_exp[i], U, p, t_exp) + nothing + end + end + else + isnothing(T_lim!) || T_lim!(T_lim[i], U, p, t_exp) + isnothing(T_exp!) || T_exp!(T_exp[i], U, p, t_exp) end end end diff --git a/src/solvers/imex_ssprk.jl b/src/solvers/imex_ssprk.jl index 646889ba0..6f870068f 100644 --- a/src/solvers/imex_ssprk.jl +++ b/src/solvers/imex_ssprk.jl @@ -56,6 +56,7 @@ function step_u!(integrator, cache::IMEXSSPRKCache) (; u, p, t, dt, alg) = integrator (; f) = integrator.sol.prob (; post_explicit!, post_implicit!) = f + (; comms_context) = f (; T_lim!, T_exp!, T_imp!, lim!, dss!) = f (; tableau, newtons_method) = alg (; a_imp, b_imp, c_exp, c_imp) = tableau @@ -153,11 +154,21 @@ function step_u!(integrator, cache::IMEXSSPRKCache) end if !iszero(β[i]) - if !isnothing(T_lim!) - T_lim!(T_lim, U, p, t_exp) - end - if !isnothing(T_exp!) - T_exp!(T_exp, U, p, t_exp) + if !isnothing(T_lim!) && !isnothing(T_lim!) && !isnothing(comms_context) + dev = ClimaComms.device(comms_context) + ClimaComms.@sync dev begin + @async begin + T_lim!(T_lim, U, p, t_exp) + nothing + end + @async begin + T_exp!(T_exp, U, p, t_exp) + nothing + end + end + else + isnothing(T_lim!) || T_lim!(T_lim, U, p, t_exp) + isnothing(T_exp!) || T_exp!(T_exp, U, p, t_exp) end end end