Skip to content

Commit

Permalink
Add support for unevaluated compute! functions
Browse files Browse the repository at this point in the history
`LazyBroadcast.jl` provides a way to return an unevaluated function.
This is useful in two cases:
1. reduce code verbosity to handle the `isnothing(out)` case
2. allow clustering all the broadcasted expressions in a single place

In turn, 2. is useful because it is the first step in fusing different
broadcasted calls.

This commit adds support for such functions.
  • Loading branch information
Sbozzolo committed Aug 14, 2024
1 parent 19005a0 commit 6e8e28c
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 6 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ ClimaCore = "0.13.4, 0.14"
ClimaTimeSteppers = "0.7.10"
Dates = "1"
Documenter = "1"
LazyBroadcast = "0.1.3"
JuliaFormatter = "1"
NCDatasets = "0.13.1, 0.14"
Profile = "1"
Expand All @@ -35,10 +36,11 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LazyBroadcast = "9dccce8e-a116-406d-9fcc-a88ed4f510c8"
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
ProfileCanvas = "efd6af41-a80b-495e-886c-e51b0c7d77a3"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "BenchmarkTools", "ClimaTimeSteppers", "Documenter", "JuliaFormatter", "Profile", "ProfileCanvas", "SafeTestsets", "Test"]
test = ["Aqua", "BenchmarkTools", "ClimaTimeSteppers", "Documenter", "JuliaFormatter", "LazyBroadcast", "Profile", "ProfileCanvas", "SafeTestsets", "Test"]
54 changes: 50 additions & 4 deletions src/clima_diagnostics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ struct DiagnosticsHandler{
STORAGE <: Dict,
ACC <: Dict,
COUNT <: Dict,
BROAD <: Dict,
}
"""An iterable with the `ScheduledDiagnostic`s that are scheduled."""
scheduled_diagnostics::SD
Expand All @@ -34,6 +35,11 @@ struct DiagnosticsHandler{
many times the given diagnostics was computed from the last time it was output to
disk."""
counters::COUNT

"""Dictionary that maps a given `ScheduledDiagnostic` to a Base.Broadcast.Broadcasted
expression. This is used to allow lazy evaluation of expressions, which can lead to
reduce code verbosity and improved performance."""
broadcasted_expressions::BROAD
end

"""
Expand All @@ -57,9 +63,11 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)

# For diagnostics that perform reductions, the storage is used for the values computed
# at each call. Reductions also save the accumulated value in accumulators.
# broadcasted_expressions maps diagnostics with LazyBroadcast objects.
storage = Dict()
accumulators = Dict()
counters = Dict()
broadcasted_expressions = Dict()

unique_scheduled_diagnostics = Tuple(unique(scheduled_diagnostics))
if length(unique_scheduled_diagnostics) != length(scheduled_diagnostics)
Expand Down Expand Up @@ -90,8 +98,18 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
isa_time_reduction = !isnothing(diag.reduction_time_func)

# The first time we call compute! we use its return value. All the subsequent times
# (in the callbacks), we will write the result in place
storage[diag] = variable.compute!(nothing, Y, p, t)
# (in the callbacks), we will write the result in place. ClimaDiagnostics supports
# LazyBroadcast.jl. In this case, the return value of `compute!` is a
# `Base.Broadcast.Broadcasted` and we have to manually materialize the result.
out_or_broadcasted_expr = variable.compute!(nothing, Y, p, t)
if out_or_broadcasted_expr isa Base.Broadcast.Broadcasted
broadcasted_expressions[diag] = out_or_broadcasted_expr
storage[diag] =
Base.Broadcast.materialize(broadcasted_expressions[diag])
else
storage[diag] = out_or_broadcasted_expr
end

counters[diag] = 1

# If it is not a reduction, call the output writer as well
Expand All @@ -115,6 +133,7 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
storage,
accumulators,
counters,
broadcasted_expressions,
)
end

Expand Down Expand Up @@ -153,13 +172,40 @@ function orchestrate_diagnostics(
active_compute[diag_index] || continue
diag = scheduled_diagnostics[diag_index]

diag.variable.compute!(
diagnostic_handler.counters[diag] += 1

# ClimaDiagnostics supports LazyBroadcast.jl. When used, the return value of
# `compute!` is a `Base.Broadcast.Broadcasted`. We materialize the output to
# diagnostic_handler.storage[diag] in the next for loop.

out_or_broadcasted_expr = diag.variable.compute!(
diagnostic_handler.storage[diag],
integrator.u,
integrator.p,
integrator.t,
)
diagnostic_handler.counters[diag] += 1
if out_or_broadcasted_expr isa Base.Broadcast.Broadcasted
diagnostic_handler.broadcasted_expressions[diag] =
out_or_broadcasted_expr
end
end

# Evaluate the lazy compute (aka, materialize everything)
for diag_index in 1:length(scheduled_diagnostics)
active_compute[diag_index] || continue
diag = scheduled_diagnostics[diag_index]
haskey(diagnostic_handler.broadcasted_expressions, diag) || continue

Base.Broadcast.materialize!(
diagnostic_handler.storage[diag],
diagnostic_handler.broadcasted_expressions[diag],
)
end

# Process possible time reductions (now we have evaluated storage[diag])
for diag_index in 1:length(scheduled_diagnostics)
active_compute[diag_index] || continue
diag = scheduled_diagnostics[diag_index]

isa_time_reduction = !isnothing(diag.reduction_time_func)
if isa_time_reduction
Expand Down
23 changes: 22 additions & 1 deletion test/integration_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import ClimaComms
ClimaComms.@import_required_backends
end

import LazyBroadcast: @lazy

const context = ClimaComms.context()
ClimaComms.init(context)

Expand Down Expand Up @@ -41,7 +43,15 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
)

function compute_my_var!(out, u, p, t)
ClimaDiagnostics.@assign out copy(u.my_var)
if isnothing(out)
return copy(u.my_var)
else
out .= u.my_var
end
end

function compute_my_var_lazy!(out, u, p, t)
return @lazy @. out = copy(u.my_var)
end

simple_var = ClimaDiagnostics.DiagnosticVariable(;
Expand All @@ -50,6 +60,12 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
long_name = "YO YO",
)

simple_var_lazy = ClimaDiagnostics.DiagnosticVariable(;
compute! = compute_my_var_lazy!,
short_name = "YO LAZY",
long_name = "YO YO LAZY",
)

average_diagnostic = ClimaDiagnostics.ScheduledDiagnostic(
variable = simple_var,
output_writer = nc_writer,
Expand All @@ -61,6 +77,10 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
variable = simple_var,
output_writer = nc_writer,
)
inst_diagnostic_lazy = ClimaDiagnostics.ScheduledDiagnostic(
variable = simple_var_lazy,
output_writer = nc_writer,
)
inst_every3s_diagnostic = ClimaDiagnostics.ScheduledDiagnostic(
variable = simple_var,
output_writer = nc_writer,
Expand All @@ -76,6 +96,7 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
scheduled_diagnostics = [
average_diagnostic,
inst_diagnostic,
inst_diagnostic_lazy,
inst_diagnostic_h5,
inst_every3s_diagnostic,
]
Expand Down

0 comments on commit 6e8e28c

Please sign in to comment.