diff --git a/docs/src/internals.md b/docs/src/internals.md index 67ad8314..160cf6e3 100644 --- a/docs/src/internals.md +++ b/docs/src/internals.md @@ -79,8 +79,17 @@ because it creates a new integrator obtained by copying all the fields of the old one and adding the diagnostics (with [`Accessors`](https://github.com/JuliaObjects/Accessors.jl)). +The `DiagnosticsHandler` also contains three `BitVectors`: `active_compute`, +`active_output`, `active_sync`. These `BitVectors` have the same length as the +number of scheduled diagnostics and signal whether something should done at a +given step. The `BitVectors` are defined and preallocated trying to reduce the +inference allocations that result from operations like `filter` on lists of +`ScheduledDiagnostics`. They are updated by a callback that is run before +`orchestrate_diagnostics` and they can be used to determine if +`orchestrate_diagnostics` should be run at all. + ## Orchestrate diagnostics One of the design goals for `orchestrate_diagnostics` is to keep all the broadcasted expression in the same function scope. This opens a path to optimize -the number of GPU kernel launches. +the number of GPU kernel launches. diff --git a/src/clima_diagnostics.jl b/src/clima_diagnostics.jl index ee32912e..afdbcee7 100644 --- a/src/clima_diagnostics.jl +++ b/src/clima_diagnostics.jl @@ -34,6 +34,12 @@ struct DiagnosticsHandler{ many times the given diagnostics was computed from the last time it was output to disk.""" counters::COUNT + + """Bitvectors that identify which diagnostics are active at the given step. This is here + mostly to reduce inference allocations that would result from operations like filter.""" + active_compute::BitVector + active_output::BitVector + active_sync::BitVector end """ @@ -105,23 +111,22 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing) end end + num_diagnostics = length(scheduled_diagnostics) + active_compute = BitVector(ntuple(_ -> false, num_diagnostics)) + active_output = BitVector(ntuple(_ -> false, num_diagnostics)) + active_sync = BitVector(ntuple(_ -> false, num_diagnostics)) + return DiagnosticsHandler( Tuple(scheduled_diagnostics), storage, accumulators, counters, + active_compute, + active_output, + active_sync, ) end -# Does the writer associated to `diag` need to be synced? -# It does only when it has a sync_schedule that is a callable and that -# callable returns true when called on the integrator -function _needs_sync(diag, integrator) - hasproperty(diag.output_writer, :sync_schedule) || return false - isnothing(diag.output_writer.sync_schedule) && return false - return diag.output_writer.sync_schedule(integrator) -end - """ orchestrate_diagnostics(integrator, diagnostic_handler::DiagnosticsHandler) @@ -133,15 +138,10 @@ function orchestrate_diagnostics( diagnostic_handler::DiagnosticsHandler, ) scheduled_diagnostics = diagnostic_handler.scheduled_diagnostics - active_compute = Bool[] - active_output = Bool[] - active_sync = Bool[] - for diag in scheduled_diagnostics - push!(active_compute, diag.compute_schedule_func(integrator)) - push!(active_output, diag.output_schedule_func(integrator)) - push!(active_sync, _needs_sync(diag, integrator)) - end + active_compute = diagnostic_handler.active_compute + active_output = diagnostic_handler.active_output + active_sync = diagnostic_handler.active_sync # Compute for diag_index in 1:length(scheduled_diagnostics) @@ -238,22 +238,92 @@ function orchestrate_diagnostics( return nothing end +# Does the writer associated to `diag` need to be synced? +# It does only when it has a sync_schedule that is a callable and that +# callable returns true when called on the integrator +function _needs_sync(diag, integrator) + hasproperty(diag.output_writer, :sync_schedule) || return false + isnothing(diag.output_writer.sync_schedule) && return false + return diag.output_writer.sync_schedule(integrator) +end + +""" + update_diagnostic_handler_bitvectors(integrator, diagnostics_handler::DiagnosticsHandler) + +Update the `active_{compute, update, sync}` bitvector in `diagnostics_handler`. + +The `diagnostics_handler` contains three bitvectors that determine which actions should be +taken at the current iterations. They are preallocated mostly to avoid inference allocations +that result from operations with groups of `ScheduledDiagnostics`, but they can also be used +to determine if `orchestrate_diagnostics` should be called or not. + +This function evaluates the various `schedule_func`s for all the `ScheduledDiagnostics` in +the `diagnostics_handler` and updates the bitvectors. This function should be called at +every step. +""" +function update_diagnostic_handler_bitvectors(integrator, diagnostics_handler) + scheduled_diagnostics = diagnostics_handler.scheduled_diagnostics + + for index in 1:length(scheduled_diagnostics) + diag = scheduled_diagnostics[index] + diagnostics_handler.active_compute[index] = + diag.compute_schedule_func(integrator) + diagnostics_handler.active_output[index] = + diag.output_schedule_func(integrator) + diagnostics_handler.active_sync[index] = _needs_sync(diag, integrator) + end + + return nothing +end + +""" + check_callback_condition(integrator, diagnostics_handler::DiagnosticsHandler) + +Return true when `orchestrate_diagnostics` should be called. +""" +function check_callback_condition(integrator, diagnostics_handler) + return any(diagnostics_handler.active_compute) || + any(diagnostics_handler.active_output) || + any(diagnostics_handler.active_sync) +end + """ - DiagnosticsCallback(diagnostics_handler::DiagnosticsHandler) + DiagnosticsCallbacks(diagnostics_handler::DiagnosticsHandler) + +Translate a `DiagnosticsHandler` into two SciML callbacks ready to be used. + +The first updates internal counters in `diagnostics_handler` that check if the diagnostics +have to be computed. The second actually computes and outputs the diagnostics. -Translate a `DiagnosticsHandler` into a SciML callback ready to be used. """ -function DiagnosticsCallback(diagnostics_handler::DiagnosticsHandler) - sciml_callback(integrator) = +function DiagnosticsCallbacks(diagnostics_handler::DiagnosticsHandler) + + # We use trivial condition to update the condition to run orchestrate_diagnostics + trivial_condition = (_, _, _) -> true + + sciml_callback_update_diagnostic_handler_bitvectors(integrator) = + update_diagnostic_handler_bitvectors(integrator, diagnostics_handler) + + orchestrate_condition = + (_, _, integrator) -> + check_callback_condition(integrator, diagnostics_handler) + + sciml_callback_orchestrate_diagnostics(integrator) = orchestrate_diagnostics(integrator, diagnostics_handler) - # SciMLBase.DiscreteCallback checks if the given condition is true at the end of each - # step. So, we set a condition that is always true, the callback is called at the end of - # every step. This callback runs `orchestrate_callbacks`, which manages which - # diagnostics functions to call - condition = (_, _, _) -> true + continuous_callbacks = () + discrete_callbacks = ( + SciMLBase.DiscreteCallback( + trivial_condition, + sciml_callback_update_diagnostic_handler_bitvectors, + ), + SciMLBase.DiscreteCallback( + orchestrate_condition, + sciml_callback_orchestrate_diagnostics, + ), + ) - return SciMLBase.DiscreteCallback(condition, sciml_callback) + return SciMLBase.CallbackSet(continuous_callbacks, discrete_callbacks) end """ @@ -263,7 +333,7 @@ end Return a new `integrator` with diagnostics defined by `scheduled_diagnostics`. `IntegratorWithDiagnostics` is conceptually similar to defining a `DiagnosticsHandler`, -constructing its associated `DiagnosticsCallback`, and adding such callback to a given +constructing its associated `DiagnosticsCallbacks`, and adding such callbacks to a given integrator. The new integrator is identical to the previous one with the only difference that it has a @@ -284,11 +354,12 @@ function IntegratorWithDiagnostics(integrator, scheduled_diagnostics) integrator.t; integrator.dt, ) - diagnostics_callback = DiagnosticsCallback(diagnostics_handler) + diagnostics_callbacks = + DiagnosticsCallbacks(diagnostics_handler).discrete_callbacks continuous_callbacks = integrator.callback.continuous_callbacks discrete_callbacks = - (integrator.callback.discrete_callbacks..., diagnostics_callback) + (integrator.callback.discrete_callbacks..., diagnostics_callbacks...) callback = SciMLBase.CallbackSet(continuous_callbacks, discrete_callbacks) Accessors.@reset integrator.callback = callback diff --git a/test/diagnostics.jl b/test/diagnostics.jl index 982a3921..d33564a9 100644 --- a/test/diagnostics.jl +++ b/test/diagnostics.jl @@ -63,7 +63,7 @@ include("TestTools.jl") dt, ) - diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler) + diag_cbs = ClimaDiagnostics.DiagnosticsCallbacks(diagnostic_handler) prob = SciMLBase.ODEProblem( ClimaTimeSteppers.ClimaODEFunction(T_exp! = exp_tendency!), @@ -71,9 +71,20 @@ include("TestTools.jl") (t0, tf), p, ) + + @test ClimaDiagnostics.check_callback_condition( + prob, + diagnostic_handler, + ) === false + algo = ClimaTimeSteppers.ExplicitAlgorithm(ClimaTimeSteppers.RK4()) - SciMLBase.solve(prob, algo, dt = dt, callback = diag_cb) + SciMLBase.solve(prob, algo, dt = dt, callback = diag_cbs) + + @test ClimaDiagnostics.check_callback_condition( + prob, + diagnostic_handler, + ) === true @test length(keys(dict_writer.dict[short_name])) == convert(Int, 1 + (tf - t0) / dt) @@ -100,7 +111,7 @@ include("TestTools.jl") dt, ) - diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler) + diag_cbs = ClimaDiagnostics.DiagnosticsCallbacks(diagnostic_handler) prob = SciMLBase.ODEProblem( ClimaTimeSteppers.ClimaODEFunction(T_exp! = exp_tendency!), @@ -110,7 +121,7 @@ include("TestTools.jl") ) algo = ClimaTimeSteppers.ExplicitAlgorithm(ClimaTimeSteppers.RK4()) - SciMLBase.solve(prob, algo, dt = dt, callback = diag_cb) + SciMLBase.solve(prob, algo, dt = dt, callback = diag_cbs) @test length(keys(dict_writer.dict[short_name])) == convert(Int, (tf - t0) / 5dt)