Skip to content

Commit

Permalink
Support asynchronous T_lim and T_exp
Browse files Browse the repository at this point in the history
Bump patch version
  • Loading branch information
charleskawczynski committed Feb 15, 2024
1 parent 4c9999a commit 471efd2
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
3 changes: 2 additions & 1 deletion src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ export ClimaODEFunction, ForwardEulerODEFunction

abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end

Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction
Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI, CC} <: AbstractClimaODEFunction
T_lim!::TL = nothing # nothing or (uₜ, u, p, t) -> ...
T_exp!::TE = nothing # nothing or (uₜ, u, p, t) -> ...
T_imp!::TI = nothing # nothing or (uₜ, u, p, t) -> ...
lim!::L = (u, p, t, u_ref) -> nothing
dss!::D = (u, p, t) -> nothing
post_explicit!::PE = (u, p, t) -> nothing
post_implicit!::PI = (u, p, t) -> nothing
comms_context::CC = nothing
end

# Don't wrap a AbstractClimaODEFunction in an ODEFunction (makes ODEProblem work).
Expand Down
21 changes: 16 additions & 5 deletions src/solvers/imex_ark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ function step_u!(integrator, cache::IMEXARKCache)
(; u, p, t, dt, alg) = integrator
(; f) = integrator.sol.prob
(; post_explicit!, post_implicit!) = f
(; comms_context) = f
(; T_lim!, T_exp!, T_imp!, lim!, dss!) = f
(; tableau, newtons_method) = alg
(; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau
Expand Down Expand Up @@ -147,11 +148,21 @@ function step_u!(integrator, cache::IMEXARKCache)
end

if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i])
if !isnothing(T_lim!)
T_lim!(T_lim[i], U, p, t_exp)
end
if !isnothing(T_exp!)
T_exp!(T_exp[i], U, p, t_exp)
if !isnothing(T_lim!) && !isnothing(T_lim!) && !isnothing(comms_context)
dev = ClimaComms.device(comms_context)
ClimaComms.@sync dev begin
@async begin
T_lim!(T_lim[i], U, p, t_exp)
nothing
end
@async begin
T_exp!(T_exp[i], U, p, t_exp)
nothing
end
end
else
isnothing(T_lim!) || T_lim!(T_lim[i], U, p, t_exp)
isnothing(T_exp!) || T_exp!(T_exp[i], U, p, t_exp)
end
end
end
Expand Down
21 changes: 16 additions & 5 deletions src/solvers/imex_ssprk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ function step_u!(integrator, cache::IMEXSSPRKCache)
(; u, p, t, dt, alg) = integrator
(; f) = integrator.sol.prob
(; post_explicit!, post_implicit!) = f
(; comms_context) = f
(; T_lim!, T_exp!, T_imp!, lim!, dss!) = f
(; tableau, newtons_method) = alg
(; a_imp, b_imp, c_exp, c_imp) = tableau
Expand Down Expand Up @@ -153,11 +154,21 @@ function step_u!(integrator, cache::IMEXSSPRKCache)
end

if !iszero(β[i])
if !isnothing(T_lim!)
T_lim!(T_lim, U, p, t_exp)
end
if !isnothing(T_exp!)
T_exp!(T_exp, U, p, t_exp)
if !isnothing(T_lim!) && !isnothing(T_lim!) && !isnothing(comms_context)
dev = ClimaComms.device(comms_context)
ClimaComms.@sync dev begin
@async begin
T_lim!(T_lim, U, p, t_exp)
nothing
end
@async begin
T_exp!(T_exp, U, p, t_exp)
nothing
end
end
else
isnothing(T_lim!) || T_lim!(T_lim, U, p, t_exp)
isnothing(T_exp!) || T_exp!(T_exp, U, p, t_exp)
end
end
end
Expand Down

0 comments on commit 471efd2

Please sign in to comment.