Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add actual condition to run orchestrate_diagnostics #48

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Add actual condition to run orchestrate_diagnostics
Sbozzolo committed May 19, 2024
commit c925edc644d1a1c2d1923360d67a4e94953b2dd9
11 changes: 10 additions & 1 deletion docs/src/internals.md
Original file line number Diff line number Diff line change
@@ -79,8 +79,17 @@ because it creates a new integrator obtained by copying all the fields of the
old one and adding the diagnostics (with
[`Accessors`](https://github.com/JuliaObjects/Accessors.jl)).

The `DiagnosticsHandler` also contains three `BitVectors`: `active_compute`,
`active_output`, `active_sync`. These `BitVectors` have the same length as the
number of scheduled diagnostics and signal whether something should done at a
given step. The `BitVectors` are defined and preallocated trying to reduce the
inference allocations that result from operations like `filter` on lists of
`ScheduledDiagnostics`. They are updated by a callback that is run before
`orchestrate_diagnostics` and they can be used to determine if
`orchestrate_diagnostics` should be run at all.

## Orchestrate diagnostics

One of the design goals for `orchestrate_diagnostics` is to keep all the
broadcasted expression in the same function scope. This opens a path to optimize
the number of GPU kernel launches.
the number of GPU kernel launches.
131 changes: 101 additions & 30 deletions src/clima_diagnostics.jl
Original file line number Diff line number Diff line change
@@ -34,6 +34,12 @@ struct DiagnosticsHandler{
many times the given diagnostics was computed from the last time it was output to
disk."""
counters::COUNT

"""Bitvectors that identify which diagnostics are active at the given step. This is here
mostly to reduce inference allocations that would result from operations like filter."""
active_compute::BitVector
active_output::BitVector
active_sync::BitVector
end

"""
@@ -105,23 +111,22 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
end
end

num_diagnostics = length(scheduled_diagnostics)
active_compute = BitVector(ntuple(_ -> false, num_diagnostics))
active_output = BitVector(ntuple(_ -> false, num_diagnostics))
active_sync = BitVector(ntuple(_ -> false, num_diagnostics))

return DiagnosticsHandler(
Tuple(scheduled_diagnostics),
storage,
accumulators,
counters,
active_compute,
active_output,
active_sync,
)
end

# Does the writer associated to `diag` need to be synced?
# It does only when it has a sync_schedule that is a callable and that
# callable returns true when called on the integrator
function _needs_sync(diag, integrator)
hasproperty(diag.output_writer, :sync_schedule) || return false
isnothing(diag.output_writer.sync_schedule) && return false
return diag.output_writer.sync_schedule(integrator)
end

"""
orchestrate_diagnostics(integrator, diagnostic_handler::DiagnosticsHandler)
@@ -133,15 +138,10 @@ function orchestrate_diagnostics(
diagnostic_handler::DiagnosticsHandler,
)
scheduled_diagnostics = diagnostic_handler.scheduled_diagnostics
active_compute = Bool[]
active_output = Bool[]
active_sync = Bool[]

for diag in scheduled_diagnostics
push!(active_compute, diag.compute_schedule_func(integrator))
push!(active_output, diag.output_schedule_func(integrator))
push!(active_sync, _needs_sync(diag, integrator))
end
active_compute = diagnostic_handler.active_compute
active_output = diagnostic_handler.active_output
active_sync = diagnostic_handler.active_sync

# Compute
for diag_index in 1:length(scheduled_diagnostics)
@@ -238,22 +238,92 @@ function orchestrate_diagnostics(
return nothing
end

# Does the writer associated to `diag` need to be synced?
# It does only when it has a sync_schedule that is a callable and that
# callable returns true when called on the integrator
function _needs_sync(diag, integrator)
hasproperty(diag.output_writer, :sync_schedule) || return false
isnothing(diag.output_writer.sync_schedule) && return false
return diag.output_writer.sync_schedule(integrator)
end

"""
update_diagnostic_handler_bitvectors(integrator, diagnostics_handler::DiagnosticsHandler)
Update the `active_{compute, update, sync}` bitvector in `diagnostics_handler`.
The `diagnostics_handler` contains three bitvectors that determine which actions should be
taken at the current iterations. They are preallocated mostly to avoid inference allocations
that result from operations with groups of `ScheduledDiagnostics`, but they can also be used
to determine if `orchestrate_diagnostics` should be called or not.
This function evaluates the various `schedule_func`s for all the `ScheduledDiagnostics` in
the `diagnostics_handler` and updates the bitvectors. This function should be called at
every step.
"""
function update_diagnostic_handler_bitvectors(integrator, diagnostics_handler)
scheduled_diagnostics = diagnostics_handler.scheduled_diagnostics

for index in 1:length(scheduled_diagnostics)
diag = scheduled_diagnostics[index]
diagnostics_handler.active_compute[index] =
diag.compute_schedule_func(integrator)
diagnostics_handler.active_output[index] =
diag.output_schedule_func(integrator)
diagnostics_handler.active_sync[index] = _needs_sync(diag, integrator)
end

return nothing
end

"""
check_callback_condition(integrator, diagnostics_handler::DiagnosticsHandler)
Return true when `orchestrate_diagnostics` should be called.
"""
function check_callback_condition(integrator, diagnostics_handler)
return any(diagnostics_handler.active_compute) ||
any(diagnostics_handler.active_output) ||
any(diagnostics_handler.active_sync)
end

"""
DiagnosticsCallback(diagnostics_handler::DiagnosticsHandler)
DiagnosticsCallbacks(diagnostics_handler::DiagnosticsHandler)
Translate a `DiagnosticsHandler` into two SciML callbacks ready to be used.
The first updates internal counters in `diagnostics_handler` that check if the diagnostics
have to be computed. The second actually computes and outputs the diagnostics.
Translate a `DiagnosticsHandler` into a SciML callback ready to be used.
"""
function DiagnosticsCallback(diagnostics_handler::DiagnosticsHandler)
sciml_callback(integrator) =
function DiagnosticsCallbacks(diagnostics_handler::DiagnosticsHandler)

# We use trivial condition to update the condition to run orchestrate_diagnostics
trivial_condition = (_, _, _) -> true

sciml_callback_update_diagnostic_handler_bitvectors(integrator) =
update_diagnostic_handler_bitvectors(integrator, diagnostics_handler)

orchestrate_condition =
(_, _, integrator) ->
check_callback_condition(integrator, diagnostics_handler)

sciml_callback_orchestrate_diagnostics(integrator) =
orchestrate_diagnostics(integrator, diagnostics_handler)

# SciMLBase.DiscreteCallback checks if the given condition is true at the end of each
# step. So, we set a condition that is always true, the callback is called at the end of
# every step. This callback runs `orchestrate_callbacks`, which manages which
# diagnostics functions to call
condition = (_, _, _) -> true
continuous_callbacks = ()
discrete_callbacks = (
SciMLBase.DiscreteCallback(
trivial_condition,
sciml_callback_update_diagnostic_handler_bitvectors,
),
SciMLBase.DiscreteCallback(
orchestrate_condition,
sciml_callback_orchestrate_diagnostics,
),
)

return SciMLBase.DiscreteCallback(condition, sciml_callback)
return SciMLBase.CallbackSet(continuous_callbacks, discrete_callbacks)
end

"""
@@ -263,7 +333,7 @@ end
Return a new `integrator` with diagnostics defined by `scheduled_diagnostics`.
`IntegratorWithDiagnostics` is conceptually similar to defining a `DiagnosticsHandler`,
constructing its associated `DiagnosticsCallback`, and adding such callback to a given
constructing its associated `DiagnosticsCallbacks`, and adding such callbacks to a given
integrator.
The new integrator is identical to the previous one with the only difference that it has a
@@ -284,11 +354,12 @@ function IntegratorWithDiagnostics(integrator, scheduled_diagnostics)
integrator.t;
integrator.dt,
)
diagnostics_callback = DiagnosticsCallback(diagnostics_handler)
diagnostics_callbacks =
DiagnosticsCallbacks(diagnostics_handler).discrete_callbacks

continuous_callbacks = integrator.callback.continuous_callbacks
discrete_callbacks =
(integrator.callback.discrete_callbacks..., diagnostics_callback)
(integrator.callback.discrete_callbacks..., diagnostics_callbacks...)
callback = SciMLBase.CallbackSet(continuous_callbacks, discrete_callbacks)

Accessors.@reset integrator.callback = callback
19 changes: 15 additions & 4 deletions test/diagnostics.jl
Original file line number Diff line number Diff line change
@@ -63,17 +63,28 @@ include("TestTools.jl")
dt,
)

diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
diag_cbs = ClimaDiagnostics.DiagnosticsCallbacks(diagnostic_handler)

prob = SciMLBase.ODEProblem(
ClimaTimeSteppers.ClimaODEFunction(T_exp! = exp_tendency!),
Y,
(t0, tf),
p,
)

@test ClimaDiagnostics.check_callback_condition(
prob,
diagnostic_handler,
) === false

algo = ClimaTimeSteppers.ExplicitAlgorithm(ClimaTimeSteppers.RK4())

SciMLBase.solve(prob, algo, dt = dt, callback = diag_cb)
SciMLBase.solve(prob, algo, dt = dt, callback = diag_cbs)

@test ClimaDiagnostics.check_callback_condition(
prob,
diagnostic_handler,
) === true

@test length(keys(dict_writer.dict[short_name])) ==
convert(Int, 1 + (tf - t0) / dt)
@@ -100,7 +111,7 @@ include("TestTools.jl")
dt,
)

diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
diag_cbs = ClimaDiagnostics.DiagnosticsCallbacks(diagnostic_handler)

prob = SciMLBase.ODEProblem(
ClimaTimeSteppers.ClimaODEFunction(T_exp! = exp_tendency!),
@@ -110,7 +121,7 @@ include("TestTools.jl")
)
algo = ClimaTimeSteppers.ExplicitAlgorithm(ClimaTimeSteppers.RK4())

SciMLBase.solve(prob, algo, dt = dt, callback = diag_cb)
SciMLBase.solve(prob, algo, dt = dt, callback = diag_cbs)

@test length(keys(dict_writer.dict[short_name])) ==
convert(Int, (tf - t0) / 5dt)