From 5b5e89339e77e8e75b993d4fafd28dc5133690e7 Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Wed, 12 Feb 2025 15:52:38 -0800 Subject: [PATCH 01/19] Set a default Clock depending on the grid --- ext/OceananigansReactantExt/Architectures.jl | 8 +++++-- .../OceananigansReactantExt.jl | 1 + ext/OceananigansReactantExt/TimeSteppers.jl | 24 +++++++++++++++++++ .../hydrostatic_free_surface_model.jl | 2 +- .../nonhydrostatic_model.jl | 2 +- .../ShallowWaterModels/shallow_water_model.jl | 2 +- src/TimeSteppers/clock.jl | 4 ++++ 7 files changed, 38 insertions(+), 5 deletions(-) create mode 100644 ext/OceananigansReactantExt/TimeSteppers.jl diff --git a/ext/OceananigansReactantExt/Architectures.jl b/ext/OceananigansReactantExt/Architectures.jl index 0573b49cc4..7483d06b4e 100644 --- a/ext/OceananigansReactantExt/Architectures.jl +++ b/ext/OceananigansReactantExt/Architectures.jl @@ -3,12 +3,15 @@ module Architectures using Reactant using Oceananigans -import Oceananigans.Architectures: device, architecture, array_type, on_architecture, unified_array, ReactantState, device_copy_to! +import Oceananigans.Architectures: device, architecture, array_type, on_architecture +import Oceananigans.Architectures: unified_array, ReactantState, device_copy_to! const ReactantKernelAbstractionsExt = Base.get_extension( Reactant, :ReactantKernelAbstractionsExt ) + const ReactantBackend = ReactantKernelAbstractionsExt.ReactantBackend + device(::ReactantState) = ReactantBackend() architecture(::Reactant.AnyConcreteRArray) = ReactantState @@ -24,6 +27,7 @@ on_architecture(::ReactantState, a::SubArray{<:Any, <:Any, <:Array}) = ConcreteR unified_array(::ReactantState, a) = a -@inline device_copy_to!(dst::Reactant.AnyConcreteRArray, src::Reactant.AnyConcreteRArray; kw...) = Base.copyto!(dst, src) +@inline device_copy_to!(dst::Reactant.AnyConcreteRArray, src::Reactant.AnyConcreteRArray; kw...) = + Base.copyto!(dst, src) end # module diff --git a/ext/OceananigansReactantExt/OceananigansReactantExt.jl b/ext/OceananigansReactantExt/OceananigansReactantExt.jl index 5e5fcad4cd..7dfe963db9 100644 --- a/ext/OceananigansReactantExt/OceananigansReactantExt.jl +++ b/ext/OceananigansReactantExt/OceananigansReactantExt.jl @@ -5,6 +5,7 @@ using Reactant include("Architectures.jl") using .Architectures + # These are additional modules that may need to be Reactantified in the future: # # include("Utils.jl") diff --git a/ext/OceananigansReactantExt/TimeSteppers.jl b/ext/OceananigansReactantExt/TimeSteppers.jl new file mode 100644 index 0000000000..191ed291b3 --- /dev/null +++ b/ext/OceananigansReactantExt/TimeSteppers.jl @@ -0,0 +1,24 @@ +module TimeSteppers + +using Reactant +using Oceananigans + +import Oceananigans.Grids: AbstractGrid +import Oceananigans.TimeSteppers: Clock + +using OceananigansReactantExt: ReactantState + +const ReactantGrid{FT, TX, TY, TZ} = AbstractGrid{FT, TX, TY, TZ, <:ReactantState} where {FT, TX, TY, TZ} + +function Clock(grid::ReactantGrid) + FT = Float64 # may change in the future + t = ConcreteRNumber(zero(FT)) + iter = ConcreteRNumber(0) + stage = ConcreteRNumber(0) + last_Δt = ConcreteRNumber(zero(FT)) + last_stage_Δt = ConcreteRNumber(zero(FT)) + return Clock(; time=t, iteration=iter, stage, last_Δt, last_stage_Δt) +end + +end # module + diff --git a/src/Models/HydrostaticFreeSurfaceModels/hydrostatic_free_surface_model.jl b/src/Models/HydrostaticFreeSurfaceModels/hydrostatic_free_surface_model.jl index 0cd379b2a4..28f6f3e10b 100644 --- a/src/Models/HydrostaticFreeSurfaceModels/hydrostatic_free_surface_model.jl +++ b/src/Models/HydrostaticFreeSurfaceModels/hydrostatic_free_surface_model.jl @@ -108,7 +108,7 @@ Keyword arguments - `vertical_coordinate`: Rulesets that define the time-evolution of the grid (ZStar/ZCoordinate). Default: `ZCoordinate()`. """ function HydrostaticFreeSurfaceModel(; grid, - clock = Clock{Float64}(time = 0), + clock = Clock(grid), momentum_advection = VectorInvariant(), tracer_advection = Centered(), buoyancy = nothing, diff --git a/src/Models/NonhydrostaticModels/nonhydrostatic_model.jl b/src/Models/NonhydrostaticModels/nonhydrostatic_model.jl index b545fbe3b4..baa79e56bb 100644 --- a/src/Models/NonhydrostaticModels/nonhydrostatic_model.jl +++ b/src/Models/NonhydrostaticModels/nonhydrostatic_model.jl @@ -112,7 +112,7 @@ Keyword arguments - `auxiliary_fields`: `NamedTuple` of auxiliary fields. Default: `nothing` """ function NonhydrostaticModel(; grid, - clock = Clock{Float64}(time = 0), + clock = Clock(grid), advection = Centered(), buoyancy = nothing, coriolis = nothing, diff --git a/src/Models/ShallowWaterModels/shallow_water_model.jl b/src/Models/ShallowWaterModels/shallow_water_model.jl index 851545a955..6a839bc5ab 100644 --- a/src/Models/ShallowWaterModels/shallow_water_model.jl +++ b/src/Models/ShallowWaterModels/shallow_water_model.jl @@ -112,7 +112,7 @@ Keyword arguments function ShallowWaterModel(; grid, gravitational_acceleration, - clock = Clock{eltype(grid)}(time=0), + clock = Clock(grid), momentum_advection = UpwindBiased(order=5), tracer_advection = WENO(), mass_advection = WENO(), diff --git a/src/TimeSteppers/clock.jl b/src/TimeSteppers/clock.jl index 48d8e5172f..a0de1935c8 100644 --- a/src/TimeSteppers/clock.jl +++ b/src/TimeSteppers/clock.jl @@ -1,6 +1,7 @@ using Adapt using Dates: AbstractTime, DateTime, Nanosecond, Millisecond using Oceananigans.Utils: prettytime +using Oceananigans.Grids: AbstractGrid import Base: show import Oceananigans.Units: Time @@ -54,6 +55,9 @@ function Clock{TT}(; time, return Clock{TT, DT}(time, last_Δt, last_stage_Δt, iteration, stage) end +# helpful default +Clock(grid::AbstractGrid) = Clock{Float64}(time=0) + function Base.summary(clock::Clock) TT = typeof(clock.time) DT = typeof(clock.last_Δt) From c12d044b95d465460915157b6ecfbb4a40ac106e Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Wed, 12 Feb 2025 22:56:19 -0800 Subject: [PATCH 02/19] Test default clock --- test/test_reactant.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_reactant.jl b/test/test_reactant.jl index c19f86fa1e..3308dd0160 100644 --- a/test/test_reactant.jl +++ b/test/test_reactant.jl @@ -23,7 +23,10 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw) r_model = ModelType(; grid=r_grid, model_kw...) grid = GridType(CPU(); grid_kw...) - model = ModelType(; grid=grid, model_kw...) + model = ModelType(; grid, model_kw...) + + @test model.clock.time isa ConcreteRNumber + @test model.clock.iteration isa ConcreteRNumber ui = randn(size(model.velocities.u)...) vi = randn(size(model.velocities.v)...) From fb38e66843abb5762d77f29888a4b5b075c5246b Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Sat, 15 Feb 2025 22:33:04 -0800 Subject: [PATCH 03/19] Fix up the implementation of the extension and clean up tests --- .../OceananigansReactantExt.jl | 2 + ext/OceananigansReactantExt/TimeSteppers.jl | 6 +-- src/TimeSteppers/clock.jl | 12 +++-- test/test_reactant.jl | 49 ++++++------------- 4 files changed, 28 insertions(+), 41 deletions(-) diff --git a/ext/OceananigansReactantExt/OceananigansReactantExt.jl b/ext/OceananigansReactantExt/OceananigansReactantExt.jl index 7dfe963db9..4a4be1f9de 100644 --- a/ext/OceananigansReactantExt/OceananigansReactantExt.jl +++ b/ext/OceananigansReactantExt/OceananigansReactantExt.jl @@ -5,6 +5,8 @@ using Reactant include("Architectures.jl") using .Architectures +include("TimeSteppers.jl") +using .TimeSteppers # These are additional modules that may need to be Reactantified in the future: # diff --git a/ext/OceananigansReactantExt/TimeSteppers.jl b/ext/OceananigansReactantExt/TimeSteppers.jl index 191ed291b3..ba960e4376 100644 --- a/ext/OceananigansReactantExt/TimeSteppers.jl +++ b/ext/OceananigansReactantExt/TimeSteppers.jl @@ -3,10 +3,10 @@ module TimeSteppers using Reactant using Oceananigans -import Oceananigans.Grids: AbstractGrid -import Oceananigans.TimeSteppers: Clock +using Oceananigans.Grids: AbstractGrid +using ..Architectures: ReactantState -using OceananigansReactantExt: ReactantState +import Oceananigans.TimeSteppers: Clock const ReactantGrid{FT, TX, TY, TZ} = AbstractGrid{FT, TX, TY, TZ, <:ReactantState} where {FT, TX, TY, TZ} diff --git a/src/TimeSteppers/clock.jl b/src/TimeSteppers/clock.jl index a0de1935c8..66887b2953 100644 --- a/src/TimeSteppers/clock.jl +++ b/src/TimeSteppers/clock.jl @@ -13,12 +13,12 @@ Keeps track of the current `time`, `last_Δt`, `iteration` number, and time-step The `stage` is updated only for multi-stage time-stepping methods. The `time::T` is either a number or a `DateTime` object. """ -mutable struct Clock{TT, DT} +mutable struct Clock{TT, DT, IT} time :: TT last_Δt :: DT last_stage_Δt :: DT - iteration :: Int - stage :: Int + iteration :: IT + stage :: IT end """ @@ -35,8 +35,9 @@ function Clock(; time, TT = typeof(time) DT = typeof(last_Δt) + IT = typeof(iteration) last_stage_Δt = convert(DT, last_Δt) - return Clock{TT, DT}(time, last_Δt, last_stage_Δt, iteration, stage) + return Clock{TT, DT, IT}(time, last_Δt, last_stage_Δt, iteration, stage) end # TODO: when supporting DateTime, this function will have to be extended @@ -51,8 +52,9 @@ function Clock{TT}(; time, DT = time_step_type(TT) last_Δt = convert(DT, last_Δt) last_stage_Δt = convert(DT, last_stage_Δt) + IT = typeof(iteration) - return Clock{TT, DT}(time, last_Δt, last_stage_Δt, iteration, stage) + return Clock{TT, DT, IT}(time, last_Δt, last_stage_Δt, iteration, stage) end # helpful default diff --git a/test/test_reactant.jl b/test/test_reactant.jl index 3308dd0160..f242b80533 100644 --- a/test/test_reactant.jl +++ b/test/test_reactant.jl @@ -22,11 +22,12 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw) r_grid = GridType(r_arch; grid_kw...) r_model = ModelType(; grid=r_grid, model_kw...) + # Basic test for the default Clock{ConcreteRNumber} + @test r_model.clock.time isa ConcreteRNumber + @test r_model.clock.iteration isa ConcreteRNumber + grid = GridType(CPU(); grid_kw...) model = ModelType(; grid, model_kw...) - - @test model.clock.time isa ConcreteRNumber - @test model.clock.iteration isa ConcreteRNumber ui = randn(size(model.velocities.u)...) vi = randn(size(model.velocities.v)...) @@ -58,12 +59,17 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw) simulation = Simulation(model; Δt, stop_iteration, verbose=false) run!(simulation) - # What we want to do with Reactant: + @test iteration(simulation) == stop_iteration + @test time(simulation) == 3Δt + + # Reactant time now: r_simulation = Simulation(r_model; Δt, stop_iteration, verbose=false) pop!(r_simulation.callbacks, :nan_checker) r_run! = @compile sync = true run!(r_simulation) r_run!(r_simulation) + @test iteration(r_simulation) == stop_iteration + @test time(r_simulation) == 3Δt # Some tests # Things ran normally: @@ -83,6 +89,12 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw) @test parent(v) ≈ parent(rv) @test parent(w) ≈ parent(rw) + # Running a few more time-steps works too: + r_simulation.stop_iteration += 2 + r_run!(r_simulation) + @test iteration(r_simulation) == 5 + @test time(r_simulation) == 5Δt + return nothing end @@ -174,32 +186,3 @@ end =# end -@testset "Reactanigans Clock{ConcreteRNumber} tests" begin - @info "Testing model time-stepping with Clock{ConcreteRNumber}..." - - # All of these may not need to be traced but this is paranoia. - FT = Float64 - t = ConcreteRNumber(zero(FT)) - iter = ConcreteRNumber(0) - stage = ConcreteRNumber(0) - last_Δt = ConcreteRNumber(zero(FT)) - last_stage_Δt = ConcreteRNumber(zero(FT)) - clock = Clock(; time=t, iteration=iter, stage, last_Δt, last_stage_Δt) - - grid = RectilinearGrid(ReactantState(); size=(10, 10, 10), halo=(3, 3, 3), extent=(10, 10, 10)) - free_surface = SplitExplicitFreeSurface(grid, substeps=10, gravitational_acceleration=1) - model = HydrostaticFreeSurfaceModel(; grid, clock, free_surface) - - Δt = 0.02 - simulation = Simulation(model; Δt, stop_iteration=3, verbose=false) - run!(simulation) - - @test iteration(simulation) == 3 - @test time(simulation) == 0.06 - - simulation.stop_iteration += 2 - run!(simulation) - @test iteration(simulation) == 5 - @test time(simulation) == 0.10 -end - From ba16173a0aa50c5520c113a8b545a9d0370d3752 Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Sat, 15 Feb 2025 23:17:20 -0800 Subject: [PATCH 04/19] Some major extensions --- .../OceananigansReactantExt.jl | 3 ++ ext/OceananigansReactantExt/Simulations.jl | 54 +++++++++++++++++++ ext/OceananigansReactantExt/TimeSteppers.jl | 6 +-- .../hydrostatic_free_surface_model.jl | 2 +- .../nonhydrostatic_model.jl | 2 +- .../ShallowWaterModels/shallow_water_model.jl | 2 +- src/Oceananigans.jl | 2 +- src/Simulations/run.jl | 3 +- src/Utils/schedules.jl | 1 - test/test_reactant.jl | 1 - 10 files changed, 66 insertions(+), 10 deletions(-) create mode 100644 ext/OceananigansReactantExt/Simulations.jl diff --git a/ext/OceananigansReactantExt/OceananigansReactantExt.jl b/ext/OceananigansReactantExt/OceananigansReactantExt.jl index 4a4be1f9de..03c3c7d600 100644 --- a/ext/OceananigansReactantExt/OceananigansReactantExt.jl +++ b/ext/OceananigansReactantExt/OceananigansReactantExt.jl @@ -8,6 +8,9 @@ using .Architectures include("TimeSteppers.jl") using .TimeSteppers +include("Simulations.jl") +using .Simulations + # These are additional modules that may need to be Reactantified in the future: # # include("Utils.jl") diff --git a/ext/OceananigansReactantExt/Simulations.jl b/ext/OceananigansReactantExt/Simulations.jl new file mode 100644 index 0000000000..74b185edcb --- /dev/null +++ b/ext/OceananigansReactantExt/Simulations.jl @@ -0,0 +1,54 @@ +module Simulations + +using Reactant +using Oceananigans + +using OrderedCollections: OrderedDict + +using ..Architectures: ReactantState +using Oceananigans: AbstractModel +using Oceananigans.Architectures: architecture + +using Oceananigans.Simulations: validate_Δt, stop_iteration_exceeded, AbstractDiagnostic, AbstractOutputWriter +import Oceananigans.Simulations: Simulation, aligned_time_step + +const ReactantModel = AbstractModel{<:Any, <:ReactantState} +const ReactantSimulation = Simulation{<:ReactantModel} + +aligned_time_step(::ReactantSimulation, Δt) = Δt + +function Simulation(model::ReactantModel; Δt, + verbose = true, + stop_iteration = Inf, + wall_time_limit = Inf, + minimum_relative_step = 0) + + Δt = validate_Δt(Δt, architecture(model)) + + diagnostics = OrderedDict{Symbol, AbstractDiagnostic}() + output_writers = OrderedDict{Symbol, AbstractOutputWriter}() + callbacks = OrderedDict{Symbol, Callback}() + + callbacks[:stop_iteration_exceeded] = Callback(stop_iteration_exceeded) + + # Convert numbers to floating point; otherwise preserve type (eg for DateTime types) + # TODO: implement TT = timetype(model) and FT = eltype(model) + TT = eltype(model) + Δt = Δt isa Number ? TT(Δt) : Δt + + return Simulation(model, + Δt, + Float64(stop_iteration), + nothing, # disallow stop_time + Float64(wall_time_limit), + diagnostics, + output_writers, + callbacks, + 0.0, + false, + false, + verbose, + Float64(minimum_relative_step)) +end + +end # module diff --git a/ext/OceananigansReactantExt/TimeSteppers.jl b/ext/OceananigansReactantExt/TimeSteppers.jl index ba960e4376..225a4eca1f 100644 --- a/ext/OceananigansReactantExt/TimeSteppers.jl +++ b/ext/OceananigansReactantExt/TimeSteppers.jl @@ -6,7 +6,7 @@ using Oceananigans using Oceananigans.Grids: AbstractGrid using ..Architectures: ReactantState -import Oceananigans.TimeSteppers: Clock +import Oceananigans.TimeSteppers: Clock, unit_time const ReactantGrid{FT, TX, TY, TZ} = AbstractGrid{FT, TX, TY, TZ, <:ReactantState} where {FT, TX, TY, TZ} @@ -15,8 +15,8 @@ function Clock(grid::ReactantGrid) t = ConcreteRNumber(zero(FT)) iter = ConcreteRNumber(0) stage = ConcreteRNumber(0) - last_Δt = ConcreteRNumber(zero(FT)) - last_stage_Δt = ConcreteRNumber(zero(FT)) + last_Δt = zero(FT) + last_stage_Δt = zero(FT) return Clock(; time=t, iteration=iter, stage, last_Δt, last_stage_Δt) end diff --git a/src/Models/HydrostaticFreeSurfaceModels/hydrostatic_free_surface_model.jl b/src/Models/HydrostaticFreeSurfaceModels/hydrostatic_free_surface_model.jl index 28f6f3e10b..93d87f9fd6 100644 --- a/src/Models/HydrostaticFreeSurfaceModels/hydrostatic_free_surface_model.jl +++ b/src/Models/HydrostaticFreeSurfaceModels/hydrostatic_free_surface_model.jl @@ -26,7 +26,7 @@ const ParticlesOrNothing = Union{Nothing, AbstractLagrangianParticles} const AbstractBGCOrNothing = Union{Nothing, AbstractBiogeochemistry} mutable struct HydrostaticFreeSurfaceModel{TS, E, A<:AbstractArchitecture, S, - G, T, V, B, R, F, P, BGC, U, C, Φ, K, AF, Z} <: AbstractModel{TS} + G, T, V, B, R, F, P, BGC, U, C, Φ, K, AF, Z} <: AbstractModel{TS, A} architecture :: A # Computer `Architecture` on which `Model` is run grid :: G # Grid of physical points on which `Model` is solved diff --git a/src/Models/NonhydrostaticModels/nonhydrostatic_model.jl b/src/Models/NonhydrostaticModels/nonhydrostatic_model.jl index baa79e56bb..3c0f7b56e2 100644 --- a/src/Models/NonhydrostaticModels/nonhydrostatic_model.jl +++ b/src/Models/NonhydrostaticModels/nonhydrostatic_model.jl @@ -30,7 +30,7 @@ const AbstractBGCOrNothing = Union{Nothing, AbstractBiogeochemistry} struct DefaultHydrostaticPressureAnomaly end mutable struct NonhydrostaticModel{TS, E, A<:AbstractArchitecture, G, T, B, R, SD, U, C, Φ, F, - V, S, K, BG, P, BGC, AF} <: AbstractModel{TS} + V, S, K, BG, P, BGC, AF} <: AbstractModel{TS, A} architecture :: A # Computer `Architecture` on which `Model` is run grid :: G # Grid of physical points on which `Model` is solved diff --git a/src/Models/ShallowWaterModels/shallow_water_model.jl b/src/Models/ShallowWaterModels/shallow_water_model.jl index 6a839bc5ab..29c9060889 100644 --- a/src/Models/ShallowWaterModels/shallow_water_model.jl +++ b/src/Models/ShallowWaterModels/shallow_water_model.jl @@ -36,7 +36,7 @@ function ShallowWaterSolutionFields(grid, bcs, prognostic_names) return NamedTuple{prognostic_names[1:3]}((u, v, h)) end -mutable struct ShallowWaterModel{G, A<:AbstractArchitecture, T, GR, V, U, R, F, E, B, Q, C, K, TS, FR} <: AbstractModel{TS} +mutable struct ShallowWaterModel{G, A<:AbstractArchitecture, T, GR, V, U, R, F, E, B, Q, C, K, TS, FR} <: AbstractModel{TS, A} grid :: G # Grid of physical points on which `Model` is solved architecture :: A # Computer `Architecture` on which `Model` is run clock :: Clock{T} # Tracks iteration number and simulation time of `Model` diff --git a/src/Oceananigans.jl b/src/Oceananigans.jl index f413aaf065..f6a269d2f6 100644 --- a/src/Oceananigans.jl +++ b/src/Oceananigans.jl @@ -157,7 +157,7 @@ const defaults = Defaults() Abstract supertype for models. """ -abstract type AbstractModel{TS} end +abstract type AbstractModel{TS, A} end """ AbstractDiagnostic diff --git a/src/Simulations/run.jl b/src/Simulations/run.jl index 07ed3277a6..e7993de817 100644 --- a/src/Simulations/run.jl +++ b/src/Simulations/run.jl @@ -47,7 +47,8 @@ function aligned_time_step(sim::Simulation, Δt) aligned_Δt = schedule_aligned_time_step(sim, aligned_Δt) # Align time step with simulation stop time - aligned_Δt = min(aligned_Δt, unit_time(sim.stop_time - clock.time)) + time_left = unit_time(sim.stop_time - clock.time) + aligned_Δt = min(aligned_Δt, time_left) # Temporary fix for https://github.com/CliMA/Oceananigans.jl/issues/1280 aligned_Δt = aligned_Δt <= 0 ? Δt : aligned_Δt diff --git a/src/Utils/schedules.jl b/src/Utils/schedules.jl index 23f2e2bab4..bf13f72d9c 100644 --- a/src/Utils/schedules.jl +++ b/src/Utils/schedules.jl @@ -105,7 +105,6 @@ For example, * `IterationInterval(100, offset=-1)` actuates at iterations `[99, 199, 299, ...]`. """ IterationInterval(interval; offset=0) = IterationInterval(interval, offset) - (schedule::IterationInterval)(model) = (model.clock.iteration - schedule.offset) % schedule.interval == 0 next_actuation_time(schedule::IterationInterval) = Inf diff --git a/test/test_reactant.jl b/test/test_reactant.jl index f242b80533..914e17bcb8 100644 --- a/test/test_reactant.jl +++ b/test/test_reactant.jl @@ -64,7 +64,6 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw) # Reactant time now: r_simulation = Simulation(r_model; Δt, stop_iteration, verbose=false) - pop!(r_simulation.callbacks, :nan_checker) r_run! = @compile sync = true run!(r_simulation) r_run!(r_simulation) From ff4480d48a689fc57b6703f7ded2b0a6ce0b2a94 Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Sun, 16 Feb 2025 21:32:59 -0800 Subject: [PATCH 05/19] Extend initialize! --- ext/OceananigansReactantExt/Simulations.jl | 67 +++++++++++++++++++++- src/Simulations/run.jl | 1 + 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/ext/OceananigansReactantExt/Simulations.jl b/ext/OceananigansReactantExt/Simulations.jl index 74b185edcb..0705c5d11d 100644 --- a/ext/OceananigansReactantExt/Simulations.jl +++ b/ext/OceananigansReactantExt/Simulations.jl @@ -6,17 +6,78 @@ using Oceananigans using OrderedCollections: OrderedDict using ..Architectures: ReactantState -using Oceananigans: AbstractModel +using Oceananigans: AbstractModel, run_diagnostic! using Oceananigans.Architectures: architecture +using Oceananigans.TimeSteppers: update_state! +using Oceananigans.OutputWriters: write_output! -using Oceananigans.Simulations: validate_Δt, stop_iteration_exceeded, AbstractDiagnostic, AbstractOutputWriter -import Oceananigans.Simulations: Simulation, aligned_time_step +using Oceananigans.Simulations: + validate_Δt, + stop_iteration_exceeded, + add_dependencies!, + reset!, + AbstractDiagnostic, + AbstractOutputWriter + +import Oceananigans.Simulations: Simulation, aligned_time_step, initialize! const ReactantModel = AbstractModel{<:Any, <:ReactantState} const ReactantSimulation = Simulation{<:ReactantModel} aligned_time_step(::ReactantSimulation, Δt) = Δt +function initialize!(sim::ReactantSimulation) + if sim.verbose + @info "Initializing simulation..." + start_time = time_ns() + end + + model = sim.model + clock = model.clock + + update_state!(model) + + # Output and diagnostics initialization + [add_dependencies!(sim.diagnostics, writer) for writer in values(sim.output_writers)] + + # Initialize schedules + scheduled_activities = Iterators.flatten((values(sim.diagnostics), + values(sim.callbacks), + values(sim.output_writers))) + + for activity in scheduled_activities + initialize!(activity.schedule, sim.model) + end + + # Reset! the model time-stepper, evaluate all diagnostics, and write all output at first iteration + @trace if clock.iteration == 0 + reset!(timestepper(sim.model)) + + # Initialize schedules and run diagnostics, callbacks, and output writers + for diag in values(sim.diagnostics) + run_diagnostic!(diag, model) + end + + for callback in values(sim.callbacks) + callback.callsite isa TimeStepCallsite && callback(sim) + end + + for writer in values(sim.output_writers) + writer.schedule(sim.model) + write_output!(writer, model) + end + end + + sim.initialized = true + + if sim.verbose + initialization_time = prettytime(1e-9 * (time_ns() - start_time)) + @info " ... simulation initialization complete ($initialization_time)" + end + + return nothing +end + function Simulation(model::ReactantModel; Δt, verbose = true, stop_iteration = Inf, diff --git a/src/Simulations/run.jl b/src/Simulations/run.jl index e7993de817..27ee0808a1 100644 --- a/src/Simulations/run.jl +++ b/src/Simulations/run.jl @@ -242,3 +242,4 @@ function initialize!(sim::Simulation) return nothing end + From 6cf14a2156f47efd8b720aa6f399aa2b6c9cb59e Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Tue, 18 Feb 2025 11:54:16 -0700 Subject: [PATCH 06/19] Working through Reactant errors --- Project.toml | 2 +- ext/OceananigansReactantExt/Models.jl | 5 + .../OceananigansReactantExt.jl | 2 +- ext/OceananigansReactantExt/Simulations.jl | 115 -------------- .../Simulations/Simulations.jl | 29 ++++ .../Simulations/run.jl | 143 ++++++++++++++++++ .../Simulations/simulation.jl | 37 +++++ ext/OceananigansReactantExt/TimeSteppers.jl | 60 +++++++- src/TimeSteppers/quasi_adams_bashforth_2.jl | 2 + test/test_reactant.jl | 14 +- 10 files changed, 286 insertions(+), 123 deletions(-) create mode 100644 ext/OceananigansReactantExt/Models.jl delete mode 100644 ext/OceananigansReactantExt/Simulations.jl create mode 100644 ext/OceananigansReactantExt/Simulations/Simulations.jl create mode 100644 ext/OceananigansReactantExt/Simulations/run.jl create mode 100644 ext/OceananigansReactantExt/Simulations/simulation.jl diff --git a/Project.toml b/Project.toml index f609ffecd5..a6a4a3b5f6 100644 --- a/Project.toml +++ b/Project.toml @@ -73,7 +73,7 @@ OffsetArrays = "1.4" OrderedCollections = "1.1" Printf = "1.9" Random = "1.9" -Reactant = "0.2.25" +Reactant = "0.2.31" Rotations = "1.0" SeawaterPolynomials = "0.3.5" SparseArrays = "1.9" diff --git a/ext/OceananigansReactantExt/Models.jl b/ext/OceananigansReactantExt/Models.jl new file mode 100644 index 0000000000..22399451c7 --- /dev/null +++ b/ext/OceananigansReactantExt/Models.jl @@ -0,0 +1,5 @@ +module Models + + + +end # module diff --git a/ext/OceananigansReactantExt/OceananigansReactantExt.jl b/ext/OceananigansReactantExt/OceananigansReactantExt.jl index 03c3c7d600..07d0875481 100644 --- a/ext/OceananigansReactantExt/OceananigansReactantExt.jl +++ b/ext/OceananigansReactantExt/OceananigansReactantExt.jl @@ -8,7 +8,7 @@ using .Architectures include("TimeSteppers.jl") using .TimeSteppers -include("Simulations.jl") +include("Simulations/Simulations.jl") using .Simulations # These are additional modules that may need to be Reactantified in the future: diff --git a/ext/OceananigansReactantExt/Simulations.jl b/ext/OceananigansReactantExt/Simulations.jl deleted file mode 100644 index 0705c5d11d..0000000000 --- a/ext/OceananigansReactantExt/Simulations.jl +++ /dev/null @@ -1,115 +0,0 @@ -module Simulations - -using Reactant -using Oceananigans - -using OrderedCollections: OrderedDict - -using ..Architectures: ReactantState -using Oceananigans: AbstractModel, run_diagnostic! -using Oceananigans.Architectures: architecture -using Oceananigans.TimeSteppers: update_state! -using Oceananigans.OutputWriters: write_output! - -using Oceananigans.Simulations: - validate_Δt, - stop_iteration_exceeded, - add_dependencies!, - reset!, - AbstractDiagnostic, - AbstractOutputWriter - -import Oceananigans.Simulations: Simulation, aligned_time_step, initialize! - -const ReactantModel = AbstractModel{<:Any, <:ReactantState} -const ReactantSimulation = Simulation{<:ReactantModel} - -aligned_time_step(::ReactantSimulation, Δt) = Δt - -function initialize!(sim::ReactantSimulation) - if sim.verbose - @info "Initializing simulation..." - start_time = time_ns() - end - - model = sim.model - clock = model.clock - - update_state!(model) - - # Output and diagnostics initialization - [add_dependencies!(sim.diagnostics, writer) for writer in values(sim.output_writers)] - - # Initialize schedules - scheduled_activities = Iterators.flatten((values(sim.diagnostics), - values(sim.callbacks), - values(sim.output_writers))) - - for activity in scheduled_activities - initialize!(activity.schedule, sim.model) - end - - # Reset! the model time-stepper, evaluate all diagnostics, and write all output at first iteration - @trace if clock.iteration == 0 - reset!(timestepper(sim.model)) - - # Initialize schedules and run diagnostics, callbacks, and output writers - for diag in values(sim.diagnostics) - run_diagnostic!(diag, model) - end - - for callback in values(sim.callbacks) - callback.callsite isa TimeStepCallsite && callback(sim) - end - - for writer in values(sim.output_writers) - writer.schedule(sim.model) - write_output!(writer, model) - end - end - - sim.initialized = true - - if sim.verbose - initialization_time = prettytime(1e-9 * (time_ns() - start_time)) - @info " ... simulation initialization complete ($initialization_time)" - end - - return nothing -end - -function Simulation(model::ReactantModel; Δt, - verbose = true, - stop_iteration = Inf, - wall_time_limit = Inf, - minimum_relative_step = 0) - - Δt = validate_Δt(Δt, architecture(model)) - - diagnostics = OrderedDict{Symbol, AbstractDiagnostic}() - output_writers = OrderedDict{Symbol, AbstractOutputWriter}() - callbacks = OrderedDict{Symbol, Callback}() - - callbacks[:stop_iteration_exceeded] = Callback(stop_iteration_exceeded) - - # Convert numbers to floating point; otherwise preserve type (eg for DateTime types) - # TODO: implement TT = timetype(model) and FT = eltype(model) - TT = eltype(model) - Δt = Δt isa Number ? TT(Δt) : Δt - - return Simulation(model, - Δt, - Float64(stop_iteration), - nothing, # disallow stop_time - Float64(wall_time_limit), - diagnostics, - output_writers, - callbacks, - 0.0, - false, - false, - verbose, - Float64(minimum_relative_step)) -end - -end # module diff --git a/ext/OceananigansReactantExt/Simulations/Simulations.jl b/ext/OceananigansReactantExt/Simulations/Simulations.jl new file mode 100644 index 0000000000..29f6de1f47 --- /dev/null +++ b/ext/OceananigansReactantExt/Simulations/Simulations.jl @@ -0,0 +1,29 @@ +module Simulations + +using Reactant +using Oceananigans + +using OrderedCollections: OrderedDict + +using ..Architectures: ReactantState +using ..TimeSteppers: ReactantModel + +using Oceananigans: run_diagnostic! +using Oceananigans.Architectures: architecture +using Oceananigans.TimeSteppers: update_state! +using Oceananigans.OutputWriters: write_output! + +using Oceananigans.Simulations: + validate_Δt, + stop_iteration_exceeded, + add_dependencies!, + reset!, + AbstractDiagnostic, + AbstractOutputWriter + +import Oceananigans.Simulations: Simulation, aligned_time_step, initialize! + +include("simulation.jl") +include("run.jl") + +end # module diff --git a/ext/OceananigansReactantExt/Simulations/run.jl b/ext/OceananigansReactantExt/Simulations/run.jl new file mode 100644 index 0000000000..b11198a889 --- /dev/null +++ b/ext/OceananigansReactantExt/Simulations/run.jl @@ -0,0 +1,143 @@ +using Oceananigans.Simulations: ModelCallsite +using Oceananigans: TimeStepCallsite, TendencyCallsite, UpdateStateCallsite +import Oceananigans.TimeSteppers: time_step! + +aligned_time_step(::ReactantSimulation, Δt) = Δt + +function initialize!(sim::ReactantSimulation) + #= + if sim.verbose + @info "Initializing simulation..." + start_time = time_ns() + end + =# + + model = sim.model + clock = model.clock + + update_state!(model) + + #= + # Output and diagnostics initialization + [add_dependencies!(sim.diagnostics, writer) for writer in values(sim.output_writers)] + + # Initialize schedules + scheduled_activities = Iterators.flatten((values(sim.diagnostics), + values(sim.callbacks), + values(sim.output_writers))) + + for activity in scheduled_activities + initialize!(activity.schedule, sim.model) + end + =# + + #= + # Reset! the model time-stepper, evaluate all diagnostics, and write all output at first iteration + @trace if clock.iteration == 0 + reset!(timestepper(sim.model)) + + # Initialize schedules and run diagnostics, callbacks, and output writers + for diag in values(sim.diagnostics) + run_diagnostic!(diag, model) + end + + for callback in values(sim.callbacks) + callback.callsite isa TimeStepCallsite && callback(sim) + end + + for writer in values(sim.output_writers) + writer.schedule(sim.model) + write_output!(writer, model) + end + end + =# + + sim.initialized = true + + #= + if sim.verbose + initialization_time = prettytime(1e-9 * (time_ns() - start_time)) + @info " ... simulation initialization complete ($initialization_time)" + end + =# + + return nothing +end + +""" Step `sim`ulation forward by one time step. """ +function time_step!(sim::ReactantSimulation) + + start_time_step = time_ns() + model_callbacks = Tuple(cb for cb in values(sim.callbacks) if cb.callsite isa ModelCallsite) + Δt = aligned_time_step(sim, sim.Δt) + + if !(sim.initialized) # execute initialization step + initialize!(sim) + initialize!(sim.model) + + if sim.running # check that initialization didn't stop time-stepping + if sim.verbose + @info "Executing initial time step..." + start_time = time_ns() + end + + # Take first time-step + time_step!(sim.model, Δt, callbacks=model_callbacks) + + if sim.verbose + elapsed_initial_step_time = prettytime(1e-9 * (time_ns() - start_time)) + @info " ... initial time step complete ($elapsed_initial_step_time)." + end + else + @warn "Simulation stopped during initialization." + end + + else # business as usual... + if Δt < sim.minimum_relative_step * sim.Δt + next_time = sim.model.clock.time + Δt + @warn "Resetting clock to $next_time and skipping time step of size Δt = $Δt" + sim.model.clock.time = next_time + else + time_step!(sim.model, Δt, callbacks=model_callbacks) + end + end + + for callback in values(sim.callbacks) + need_to_call = callback.schedule(sim.model) + @trace if need_to_call + callback(sim) + end + + #= + @trace if callback.callsite isa TimeStepCallsite + if callback.schedule(sim.model) + callback(sim) + else + nothing + end + else + nothing + end + =# + end + + #= + # Callbacks and callback-like things + for diag in values(sim.diagnostics) + diag.schedule(sim.model) && run_diagnostic!(diag, sim.model) + end + + + for writer in values(sim.output_writers) + writer.schedule(sim.model) && write_output!(writer, sim.model) + end + + end_time_step = time_ns() + + # Increment the wall clock + sim.run_wall_time += 1e-9 * (end_time_step - start_time_step) + =# + + return nothing +end + diff --git a/ext/OceananigansReactantExt/Simulations/simulation.jl b/ext/OceananigansReactantExt/Simulations/simulation.jl new file mode 100644 index 0000000000..c8f7bf8fae --- /dev/null +++ b/ext/OceananigansReactantExt/Simulations/simulation.jl @@ -0,0 +1,37 @@ +const ReactantSimulation = Simulation{<:ReactantModel} + +function Simulation(model::ReactantModel; Δt, + verbose = true, + stop_iteration = Inf, + wall_time_limit = Inf, + minimum_relative_step = 0) + + Δt = validate_Δt(Δt, architecture(model)) + + diagnostics = OrderedDict{Symbol, AbstractDiagnostic}() + output_writers = OrderedDict{Symbol, AbstractOutputWriter}() + callbacks = OrderedDict{Symbol, Callback}() + + callbacks[:stop_iteration_exceeded] = Callback(stop_iteration_exceeded) + + # Convert numbers to floating point; otherwise preserve type (eg for DateTime types) + # TODO: implement TT = timetype(model) and FT = eltype(model) + TT = eltype(model) + Δt = Δt isa Number ? TT(Δt) : Δt + + return Simulation(model, + Δt, + Float64(stop_iteration), + nothing, # disallow stop_time + Float64(wall_time_limit), + diagnostics, + output_writers, + callbacks, + 0.0, + false, + false, + verbose, + Float64(minimum_relative_step)) +end + + diff --git a/ext/OceananigansReactantExt/TimeSteppers.jl b/ext/OceananigansReactantExt/TimeSteppers.jl index 225a4eca1f..87e8258d51 100644 --- a/ext/OceananigansReactantExt/TimeSteppers.jl +++ b/ext/OceananigansReactantExt/TimeSteppers.jl @@ -1,14 +1,26 @@ module TimeSteppers +using ..Architectures: ReactantState + using Reactant using Oceananigans +using Oceananigans: AbstractModel using Oceananigans.Grids: AbstractGrid -using ..Architectures: ReactantState +using Oceananigans.Utils: @apply_regionally, apply_regionally! +using Oceananigans.TimeSteppers: + update_state!, + tick!, + calculate_pressure_correction!, + correct_velocities_and_store_tendencies!, + step_lagrangian_particles!, + QuasiAdamsBashforth2TimeStepper, + ab2_step! -import Oceananigans.TimeSteppers: Clock, unit_time +import Oceananigans.TimeSteppers: Clock, unit_time, time_step! const ReactantGrid{FT, TX, TY, TZ} = AbstractGrid{FT, TX, TY, TZ, <:ReactantState} where {FT, TX, TY, TZ} +const ReactantModel{TS} = AbstractModel{TS, <:ReactantState} where TS function Clock(grid::ReactantGrid) FT = Float64 # may change in the future @@ -20,5 +32,49 @@ function Clock(grid::ReactantGrid) return Clock(; time=t, iteration=iter, stage, last_Δt, last_stage_Δt) end +function time_step!(model::ReactantModel{<:QuasiAdamsBashforth2TimeStepper}, Δt; + callbacks=[], euler=false) + + #= + # Be paranoid and update state at iteration 0 + @trace if model.clock.iteration == 0 + update_state!(model, callbacks; compute_tendencies=true) + end + =# + + # Take an euler step if: + # * We detect that the time-step size has changed. + # * We detect that this is the "first" time-step, which means we + # need to take an euler step. Note that model.clock.last_Δt is + # initialized as Inf + # * The user has passed euler=true to time_step! + euler = euler || (Δt != model.clock.last_Δt) + + # If euler, then set χ = -0.5 + minus_point_five = convert(eltype(model.grid), -0.5) + ab2_timestepper = model.timestepper + χ = ifelse(euler, minus_point_five, ab2_timestepper.χ) + χ₀ = ab2_timestepper.χ # Save initial value + ab2_timestepper.χ = χ + + # Full step for tracers, fractional step for velocities. + ab2_step!(model, Δt) + + tick!(model.clock, Δt) + model.clock.last_Δt = Δt + model.clock.last_stage_Δt = Δt # just one stage + + calculate_pressure_correction!(model, Δt) + @apply_regionally correct_velocities_and_store_tendencies!(model, Δt) + + update_state!(model, callbacks; compute_tendencies=true) + step_lagrangian_particles!(model, Δt) + + # Return χ to initial value + ab2_timestepper.χ = χ₀ + + return nothing +end + end # module diff --git a/src/TimeSteppers/quasi_adams_bashforth_2.jl b/src/TimeSteppers/quasi_adams_bashforth_2.jl index d099ba7a8d..9f39c94408 100644 --- a/src/TimeSteppers/quasi_adams_bashforth_2.jl +++ b/src/TimeSteppers/quasi_adams_bashforth_2.jl @@ -74,6 +74,8 @@ The steps of the Quasi-Adams-Bashforth second-order (AB2) algorithm are: function time_step!(model::AbstractModel{<:QuasiAdamsBashforth2TimeStepper}, Δt; callbacks=[], euler=false) + @info "Not using Reactantified time_step!" + Δt == 0 && @warn "Δt == 0 may cause model blowup!" # Be paranoid and update state at iteration 0 diff --git a/test/test_reactant.jl b/test/test_reactant.jl index 914e17bcb8..3b0e35d171 100644 --- a/test/test_reactant.jl +++ b/test/test_reactant.jl @@ -59,14 +59,21 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw) simulation = Simulation(model; Δt, stop_iteration, verbose=false) run!(simulation) + @info " After running 3 time steps, the non-reactant model:" @test iteration(simulation) == stop_iteration @test time(simulation) == 3Δt # Reactant time now: r_simulation = Simulation(r_model; Δt, stop_iteration, verbose=false) - r_run! = @compile sync = true run!(r_simulation) + @info " Compiling r_run!:" + r_run! = @compile sync=true run!(r_simulation) + # r_run! = @compile run!(r_simulation) + + @info " Executing r_run!:" r_run!(r_simulation) + + @info " After running 3 time steps, the reactant model:" @test iteration(r_simulation) == stop_iteration @test time(r_simulation) == 3Δt @@ -75,7 +82,6 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw) @test iteration(r_simulation) == iteration(simulation) @test time(r_simulation) == time(simulation) - @info " After running 3 time steps:" @show maximum(abs, parent(u)) @show maximum(abs, parent(v)) @show maximum(abs, parent(w)) @@ -152,7 +158,7 @@ end end @testset "Reactant Super Simple Simulation Tests" begin - nonhydrostatic_model_kw = (; advection=WENO()) + # nonhydrostatic_model_kw = (; advection=WENO()) hydrostatic_model_kw = (; momentum_advection=WENO()) Nx, Ny, Nz = (10, 10, 10) # number of cells halo = (7, 7, 7) @@ -162,7 +168,7 @@ end lat_lon_kw = (; size=(Nx, Ny, Nz), halo, longitude, latitude, z) rectilinear_kw = (; size=(Nx, Ny, Nz), halo, x=(0, 1), y=(0, 1), z=(0, 1)) - # FFTs are not supported by Reactant so we don't run this test: + # We don't yet support NonhydrostaticModel: # @info "Testing RectilinearGrid + NonhydrostaticModel Reactant correctness" # test_reactant_model_correctness(RectilinearGrid, NonhydrostaticModel, rectilinear_kw, nonhydrostatic_model_kw) From 037d15fd171ccaf73587c9183e09a98e20726056 Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Sat, 22 Feb 2025 12:32:11 -0700 Subject: [PATCH 07/19] Trying a few things plus extend stop_iteration_exceeded --- .../Simulations/Simulations.jl | 6 +++- .../Simulations/run.jl | 4 +++ .../Simulations/simulation.jl | 35 +++++++++++++++---- ext/OceananigansReactantExt/TimeSteppers.jl | 4 ++- 4 files changed, 41 insertions(+), 8 deletions(-) diff --git a/ext/OceananigansReactantExt/Simulations/Simulations.jl b/ext/OceananigansReactantExt/Simulations/Simulations.jl index 29f6de1f47..d604e93e1e 100644 --- a/ext/OceananigansReactantExt/Simulations/Simulations.jl +++ b/ext/OceananigansReactantExt/Simulations/Simulations.jl @@ -21,7 +21,11 @@ using Oceananigans.Simulations: AbstractDiagnostic, AbstractOutputWriter -import Oceananigans.Simulations: Simulation, aligned_time_step, initialize! +import Oceananigans.Simulations: + Simulation, + aligned_time_step, + initialize!, + stop_iteration_exceeded include("simulation.jl") include("run.jl") diff --git a/ext/OceananigansReactantExt/Simulations/run.jl b/ext/OceananigansReactantExt/Simulations/run.jl index b11198a889..0c1ed21d75 100644 --- a/ext/OceananigansReactantExt/Simulations/run.jl +++ b/ext/OceananigansReactantExt/Simulations/run.jl @@ -93,6 +93,8 @@ function time_step!(sim::ReactantSimulation) end else # business as usual... + time_step!(sim.model, Δt, callbacks=model_callbacks) + #= if Δt < sim.minimum_relative_step * sim.Δt next_time = sim.model.clock.time + Δt @warn "Resetting clock to $next_time and skipping time step of size Δt = $Δt" @@ -100,6 +102,7 @@ function time_step!(sim::ReactantSimulation) else time_step!(sim.model, Δt, callbacks=model_callbacks) end + =# end for callback in values(sim.callbacks) @@ -120,6 +123,7 @@ function time_step!(sim::ReactantSimulation) end =# end + =# #= # Callbacks and callback-like things diff --git a/ext/OceananigansReactantExt/Simulations/simulation.jl b/ext/OceananigansReactantExt/Simulations/simulation.jl index c8f7bf8fae..c3d55bb921 100644 --- a/ext/OceananigansReactantExt/Simulations/simulation.jl +++ b/ext/OceananigansReactantExt/Simulations/simulation.jl @@ -2,9 +2,7 @@ const ReactantSimulation = Simulation{<:ReactantModel} function Simulation(model::ReactantModel; Δt, verbose = true, - stop_iteration = Inf, - wall_time_limit = Inf, - minimum_relative_step = 0) + stop_iteration = Inf) Δt = validate_Δt(Δt, architecture(model)) @@ -19,11 +17,13 @@ function Simulation(model::ReactantModel; Δt, TT = eltype(model) Δt = Δt isa Number ? TT(Δt) : Δt + stop_iteration = ConcreteRNumber(Float64(stop_iteration)) + return Simulation(model, Δt, - Float64(stop_iteration), + stop_iteration, nothing, # disallow stop_time - Float64(wall_time_limit), + Inf, diagnostics, output_writers, callbacks, @@ -31,7 +31,30 @@ function Simulation(model::ReactantModel; Δt, false, false, verbose, - Float64(minimum_relative_step)) + 0.0) +end + +function stop_iteration_exceeded(sim::ReactantSimulation) + #= + @trace if sim.model.clock.iteration >= sim.stop_iteration + #= + if sim.verbose + msg = string("Model iteration ", + iteration(sim), + " equals or exceeds stop iteration ", + Int(sim.stop_iteration), + ".") + + @info wall_time_msg(sim) + @info msg + end + =# + + sim.running = false + end + =# + + return nothing end diff --git a/ext/OceananigansReactantExt/TimeSteppers.jl b/ext/OceananigansReactantExt/TimeSteppers.jl index 87e8258d51..0727c71437 100644 --- a/ext/OceananigansReactantExt/TimeSteppers.jl +++ b/ext/OceananigansReactantExt/TimeSteppers.jl @@ -48,7 +48,9 @@ function time_step!(model::ReactantModel{<:QuasiAdamsBashforth2TimeStepper}, Δt # need to take an euler step. Note that model.clock.last_Δt is # initialized as Inf # * The user has passed euler=true to time_step! - euler = euler || (Δt != model.clock.last_Δt) + @trace if Δt != model.clock.last_Δt + euler = true + end # If euler, then set χ = -0.5 minus_point_five = convert(eltype(model.grid), -0.5) From 811dca59db5dd68baa42a119496887db2c1ccd87 Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Mon, 24 Feb 2025 12:41:37 -0700 Subject: [PATCH 08/19] Updates --- ext/OceananigansReactantExt/Architectures.jl | 2 ++ .../Simulations/run.jl | 19 +++++++++++++------ .../Simulations/simulation.jl | 3 ++- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/ext/OceananigansReactantExt/Architectures.jl b/ext/OceananigansReactantExt/Architectures.jl index 7483d06b4e..070018e23d 100644 --- a/ext/OceananigansReactantExt/Architectures.jl +++ b/ext/OceananigansReactantExt/Architectures.jl @@ -3,6 +3,8 @@ module Architectures using Reactant using Oceananigans +using Reactant: AnyConcreteRArray + import Oceananigans.Architectures: device, architecture, array_type, on_architecture import Oceananigans.Architectures: unified_array, ReactantState, device_copy_to! diff --git a/ext/OceananigansReactantExt/Simulations/run.jl b/ext/OceananigansReactantExt/Simulations/run.jl index 0c1ed21d75..a903aaf0e4 100644 --- a/ext/OceananigansReactantExt/Simulations/run.jl +++ b/ext/OceananigansReactantExt/Simulations/run.jl @@ -14,7 +14,6 @@ function initialize!(sim::ReactantSimulation) model = sim.model clock = model.clock - update_state!(model) #= @@ -67,7 +66,7 @@ end """ Step `sim`ulation forward by one time step. """ function time_step!(sim::ReactantSimulation) - start_time_step = time_ns() + #start_time_step = time_ns() model_callbacks = Tuple(cb for cb in values(sim.callbacks) if cb.callsite isa ModelCallsite) Δt = aligned_time_step(sim, sim.Δt) @@ -77,16 +76,16 @@ function time_step!(sim::ReactantSimulation) if sim.running # check that initialization didn't stop time-stepping if sim.verbose - @info "Executing initial time step..." - start_time = time_ns() + # @info "Executing initial time step..." + #start_time = time_ns() end # Take first time-step time_step!(sim.model, Δt, callbacks=model_callbacks) if sim.verbose - elapsed_initial_step_time = prettytime(1e-9 * (time_ns() - start_time)) - @info " ... initial time step complete ($elapsed_initial_step_time)." + #elapsed_initial_step_time = prettytime(1e-9 * (time_ns() - start_time)) + # @info " ... initial time step complete ($elapsed_initial_step_time)." end else @warn "Simulation stopped during initialization." @@ -105,6 +104,14 @@ function time_step!(sim::ReactantSimulation) =# end + stop_sim = iteration(sim) >= sim.stop_iteration + @trace if stop_sim + sim.running = false + else + nothing + end + + #= for callback in values(sim.callbacks) need_to_call = callback.schedule(sim.model) @trace if need_to_call diff --git a/ext/OceananigansReactantExt/Simulations/simulation.jl b/ext/OceananigansReactantExt/Simulations/simulation.jl index c3d55bb921..8b1c0a8ca5 100644 --- a/ext/OceananigansReactantExt/Simulations/simulation.jl +++ b/ext/OceananigansReactantExt/Simulations/simulation.jl @@ -17,7 +17,8 @@ function Simulation(model::ReactantModel; Δt, TT = eltype(model) Δt = Δt isa Number ? TT(Δt) : Δt - stop_iteration = ConcreteRNumber(Float64(stop_iteration)) + #stop_iteration = ConcreteRNumber(Float64(stop_iteration)) + stop_iteration = Float64(stop_iteration) return Simulation(model, Δt, From 912db0e93b5a080012fa1c6f1f82e3af48ef43b0 Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Tue, 25 Feb 2025 17:28:33 -0700 Subject: [PATCH 09/19] Add test code --- test/test_reactant.jl | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/test/test_reactant.jl b/test/test_reactant.jl index 3b0e35d171..c9e41d4034 100644 --- a/test/test_reactant.jl +++ b/test/test_reactant.jl @@ -17,6 +17,42 @@ using KernelAbstractions: @kernel, @index using GPUArrays using Random +#= +using Reactant +using Reactant.ReactantCore + +mutable struct TestClock{I} + iteration :: I +end + +mutable struct TestSimulation{C, I, B} + clock :: C + stop_iteration :: I + running :: B +end + +function step!(sim) + cond = sim.clock.iteration >= sim.stop_iteration + @trace if cond + sim.running = false + else + sim.clock.iteration += 1 # time step + end + return sim # note, this function returns sim which is used as an argument for the next while-loop iteration. +end + +function test_run!(sim) + ReactantCore.traced_while(sim->sim.running, step!, (sim, )) +end + +clock = TestClock(ConcreteRNumber(0)) +simulation = TestSimulation(clock, ConcreteRNumber(3), ConcreteRNumber(true)) +# @code_hlo optimize=false test_run!(simulation) + +r_run! = @compile sync=true test_run!(simulation) +r_run!(simulation) +=# + function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw) r_arch = ReactantState() r_grid = GridType(r_arch; grid_kw...) @@ -114,6 +150,7 @@ end @inbounds f[i, j, k] += 1 end +#= @testset "Reactanigans unit tests" begin @info "Performing Reactanigans unit tests..." arch = ReactantState() @@ -156,6 +193,7 @@ end @test cd[1, 2, 3] == 2 * (x[1] + y[2] * z[3]) end end +=# @testset "Reactant Super Simple Simulation Tests" begin # nonhydrostatic_model_kw = (; advection=WENO()) From 70effaa9fed20b5fccf5a3b259e302eecc0aa8b7 Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Thu, 27 Feb 2025 16:08:03 -0700 Subject: [PATCH 10/19] Finally get test_reactant to copmile for a long time --- Project.toml | 3 +- .../OceananigansReactantExt.jl | 10 +++ .../Simulations/run.jl | 69 ++++++++++--------- src/Fields/interpolate.jl | 1 + test/test_reactant.jl | 6 +- 5 files changed, 54 insertions(+), 35 deletions(-) diff --git a/Project.toml b/Project.toml index d4ae23f31a..67b45b7fa3 100644 --- a/Project.toml +++ b/Project.toml @@ -41,11 +41,12 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" [extensions] OceananigansEnzymeExt = "Enzyme" OceananigansMakieExt = ["MakieCore", "Makie"] -OceananigansReactantExt = ["Reactant", "KernelAbstractions"] +OceananigansReactantExt = ["Reactant", "KernelAbstractions", "ConstructionBase"] [compat] Adapt = "4.1.1" diff --git a/ext/OceananigansReactantExt/OceananigansReactantExt.jl b/ext/OceananigansReactantExt/OceananigansReactantExt.jl index 07d0875481..77a77f8383 100644 --- a/ext/OceananigansReactantExt/OceananigansReactantExt.jl +++ b/ext/OceananigansReactantExt/OceananigansReactantExt.jl @@ -1,6 +1,7 @@ module OceananigansReactantExt using Reactant +using Oceananigans include("Architectures.jl") using .Architectures @@ -11,6 +12,15 @@ using .TimeSteppers include("Simulations/Simulations.jl") using .Simulations +##### +##### Telling Reactant how to construct types +##### + +import ConstructionBase: constructorof + +constructorof(::Type{<:RectilinearGrid{FT, TX, TY, TZ}}) where {FT, TX, TY, TZ} = RectilinearGrid{TX, TY, TZ} +constructorof(::Type{<:VectorInvariant{N, FT, M}}) where {N, FT, M} = VectorInvariant{N, FT, M} + # These are additional modules that may need to be Reactantified in the future: # # include("Utils.jl") diff --git a/ext/OceananigansReactantExt/Simulations/run.jl b/ext/OceananigansReactantExt/Simulations/run.jl index a903aaf0e4..d57cf09c90 100644 --- a/ext/OceananigansReactantExt/Simulations/run.jl +++ b/ext/OceananigansReactantExt/Simulations/run.jl @@ -67,44 +67,20 @@ end function time_step!(sim::ReactantSimulation) #start_time_step = time_ns() - model_callbacks = Tuple(cb for cb in values(sim.callbacks) if cb.callsite isa ModelCallsite) - Δt = aligned_time_step(sim, sim.Δt) + # Δt = aligned_time_step(sim, sim.Δt) + Δt = sim.Δt if !(sim.initialized) # execute initialization step initialize!(sim) initialize!(sim.model) - - if sim.running # check that initialization didn't stop time-stepping - if sim.verbose - # @info "Executing initial time step..." - #start_time = time_ns() - end - - # Take first time-step - time_step!(sim.model, Δt, callbacks=model_callbacks) - - if sim.verbose - #elapsed_initial_step_time = prettytime(1e-9 * (time_ns() - start_time)) - # @info " ... initial time step complete ($elapsed_initial_step_time)." - end - else - @warn "Simulation stopped during initialization." - end - - else # business as usual... - time_step!(sim.model, Δt, callbacks=model_callbacks) - #= - if Δt < sim.minimum_relative_step * sim.Δt - next_time = sim.model.clock.time + Δt - @warn "Resetting clock to $next_time and skipping time step of size Δt = $Δt" - sim.model.clock.time = next_time - else - time_step!(sim.model, Δt, callbacks=model_callbacks) - end - =# end + # model_callbacks = Tuple(cb for cb in values(sim.callbacks) if cb.callsite isa ModelCallsite) + model_callbacks = tuple() + time_step!(sim.model, Δt, callbacks=model_callbacks) + stop_sim = iteration(sim) >= sim.stop_iteration + @trace if stop_sim sim.running = false else @@ -152,3 +128,34 @@ function time_step!(sim::ReactantSimulation) return nothing end +function run!(sim::ReactantSimulation; pickup=false) + + #= + start_run = time_ns() + + if we_want_to_pickup(pickup) + set!(sim, pickup) + end + =# + + sim.initialized = false + sim.running = true + sim.run_wall_time = 0.0 + + while sim.running + time_step!(sim) + end + + #= + for callback in values(sim.callbacks) + finalize!(callback, sim) + end + + # Increment the wall clock + end_run = time_ns() + sim.run_wall_time += 1e-9 * (end_run - start_run) + =# + + return nothing +end + diff --git a/src/Fields/interpolate.jl b/src/Fields/interpolate.jl index 06b715c129..70a7bccc17 100644 --- a/src/Fields/interpolate.jl +++ b/src/Fields/interpolate.jl @@ -398,3 +398,4 @@ function interpolate!(to_field::Field, from_field::AbstractField) return to_field end + diff --git a/test/test_reactant.jl b/test/test_reactant.jl index c9e41d4034..04f37d61ef 100644 --- a/test/test_reactant.jl +++ b/test/test_reactant.jl @@ -215,17 +215,17 @@ end test_reactant_model_correctness(RectilinearGrid, HydrostaticFreeSurfaceModel, rectilinear_kw, hydrostatic_model_kw) @info "Testing LatitudeLongitudeGrid + HydrostaticFreeSurfaceModel Reactant correctness" - hydrostatic_model_kw = (; momentum_advection=WENO()) + hydrostatic_model_kw = (; momentum_advection = WENO()) test_reactant_model_correctness(LatitudeLongitudeGrid, HydrostaticFreeSurfaceModel, lat_lon_kw, hydrostatic_model_kw) - #= + @info "Testing LatitudeLongitudeGrid + 'complicated HydrostaticFreeSurfaceModel' Reactant correctness" equation_of_state = TEOS10EquationOfState() hydrostatic_model_kw = (momentum_advection = WENOVectorInvariant(), tracer_advection = WENO(), tracers = (:T, :S, :e), buoyancy = SeawaterBuoyancy(; equation_of_state), closure = CATKEVerticalDiffusivity()) + test_reactant_model_correctness(LatitudeLongitudeGrid, HydrostaticFreeSurfaceModel, lat_lon_kw, hydrostatic_model_kw) - =# end From 50eb3de3f5c485610c2daf5198602c05108b7e2e Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Thu, 27 Feb 2025 16:21:50 -0700 Subject: [PATCH 11/19] Bugfix Reactant Simulation constructor --- ext/OceananigansReactantExt/Simulations/simulation.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/ext/OceananigansReactantExt/Simulations/simulation.jl b/ext/OceananigansReactantExt/Simulations/simulation.jl index 8b1c0a8ca5..69fabe3fbc 100644 --- a/ext/OceananigansReactantExt/Simulations/simulation.jl +++ b/ext/OceananigansReactantExt/Simulations/simulation.jl @@ -31,6 +31,7 @@ function Simulation(model::ReactantModel; Δt, 0.0, false, false, + false, verbose, 0.0) end From f1d1fbfcef9553bc20f000da2885b6696c031983 Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Sat, 8 Mar 2025 11:21:19 -0700 Subject: [PATCH 12/19] Refactor simulation with Reactant, discontinue support for run! --- ext/OceananigansReactantExt/Models.jl | 5 - .../OceananigansReactantExt.jl | 4 + .../Simulations/Simulations.jl | 4 + .../Simulations/run.jl | 162 +----------------- .../Simulations/simulation.jl | 46 +---- ext/OceananigansReactantExt/TimeSteppers.jl | 21 ++- ext/OceananigansReactantExt/Utils.jl | 18 ++ .../show_hydrostatic_free_surface_model.jl | 3 +- src/TimeSteppers/clock.jl | 7 +- src/Utils/multi_region_transformation.jl | 2 +- 10 files changed, 63 insertions(+), 209 deletions(-) delete mode 100644 ext/OceananigansReactantExt/Models.jl create mode 100644 ext/OceananigansReactantExt/Utils.jl diff --git a/ext/OceananigansReactantExt/Models.jl b/ext/OceananigansReactantExt/Models.jl deleted file mode 100644 index 22399451c7..0000000000 --- a/ext/OceananigansReactantExt/Models.jl +++ /dev/null @@ -1,5 +0,0 @@ -module Models - - - -end # module diff --git a/ext/OceananigansReactantExt/OceananigansReactantExt.jl b/ext/OceananigansReactantExt/OceananigansReactantExt.jl index 7be4361644..71d0bd7cef 100644 --- a/ext/OceananigansReactantExt/OceananigansReactantExt.jl +++ b/ext/OceananigansReactantExt/OceananigansReactantExt.jl @@ -7,6 +7,9 @@ using OffsetArrays deconcretize(obj) = obj # fallback deconcretize(a::OffsetArray) = OffsetArray(Array(a.parent), a.offsets...) +include("Utils.jl") +using .Utils + include("Architectures.jl") using .Architectures @@ -46,3 +49,4 @@ constructorof(::Type{<:VectorInvariant{N, FT, M}}) where {N, FT, M} = VectorInva # using .Solvers end # module + diff --git a/ext/OceananigansReactantExt/Simulations/Simulations.jl b/ext/OceananigansReactantExt/Simulations/Simulations.jl index d604e93e1e..adfdba9ac7 100644 --- a/ext/OceananigansReactantExt/Simulations/Simulations.jl +++ b/ext/OceananigansReactantExt/Simulations/Simulations.jl @@ -22,11 +22,15 @@ using Oceananigans.Simulations: AbstractOutputWriter import Oceananigans.Simulations: + iteration, + add_callback!, Simulation, aligned_time_step, initialize!, stop_iteration_exceeded +import Oceananigans.TimeSteppers: time_step! + include("simulation.jl") include("run.jl") diff --git a/ext/OceananigansReactantExt/Simulations/run.jl b/ext/OceananigansReactantExt/Simulations/run.jl index d57cf09c90..125fc733ab 100644 --- a/ext/OceananigansReactantExt/Simulations/run.jl +++ b/ext/OceananigansReactantExt/Simulations/run.jl @@ -1,161 +1,13 @@ -using Oceananigans.Simulations: ModelCallsite -using Oceananigans: TimeStepCallsite, TendencyCallsite, UpdateStateCallsite -import Oceananigans.TimeSteppers: time_step! - -aligned_time_step(::ReactantSimulation, Δt) = Δt - -function initialize!(sim::ReactantSimulation) - #= - if sim.verbose - @info "Initializing simulation..." - start_time = time_ns() - end - =# - - model = sim.model - clock = model.clock - update_state!(model) - - #= - # Output and diagnostics initialization - [add_dependencies!(sim.diagnostics, writer) for writer in values(sim.output_writers)] - - # Initialize schedules - scheduled_activities = Iterators.flatten((values(sim.diagnostics), - values(sim.callbacks), - values(sim.output_writers))) - - for activity in scheduled_activities - initialize!(activity.schedule, sim.model) - end - =# - - #= - # Reset! the model time-stepper, evaluate all diagnostics, and write all output at first iteration - @trace if clock.iteration == 0 - reset!(timestepper(sim.model)) - - # Initialize schedules and run diagnostics, callbacks, and output writers - for diag in values(sim.diagnostics) - run_diagnostic!(diag, model) - end - - for callback in values(sim.callbacks) - callback.callsite isa TimeStepCallsite && callback(sim) - end - - for writer in values(sim.output_writers) - writer.schedule(sim.model) - write_output!(writer, model) - end - end - =# - - sim.initialized = true - - #= - if sim.verbose - initialization_time = prettytime(1e-9 * (time_ns() - start_time)) - @info " ... simulation initialization complete ($initialization_time)" - end - =# - - return nothing -end +import ..TimeSteppers: first_time_step! """ Step `sim`ulation forward by one time step. """ -function time_step!(sim::ReactantSimulation) - - #start_time_step = time_ns() - # Δt = aligned_time_step(sim, sim.Δt) - Δt = sim.Δt - - if !(sim.initialized) # execute initialization step - initialize!(sim) - initialize!(sim.model) - end - - # model_callbacks = Tuple(cb for cb in values(sim.callbacks) if cb.callsite isa ModelCallsite) - model_callbacks = tuple() - time_step!(sim.model, Δt, callbacks=model_callbacks) - - stop_sim = iteration(sim) >= sim.stop_iteration - - @trace if stop_sim - sim.running = false - else - nothing - end - - #= - for callback in values(sim.callbacks) - need_to_call = callback.schedule(sim.model) - @trace if need_to_call - callback(sim) - end - - #= - @trace if callback.callsite isa TimeStepCallsite - if callback.schedule(sim.model) - callback(sim) - else - nothing - end - else - nothing - end - =# - end - =# - - #= - # Callbacks and callback-like things - for diag in values(sim.diagnostics) - diag.schedule(sim.model) && run_diagnostic!(diag, sim.model) - end - - - for writer in values(sim.output_writers) - writer.schedule(sim.model) && write_output!(writer, sim.model) - end - - end_time_step = time_ns() - - # Increment the wall clock - sim.run_wall_time += 1e-9 * (end_time_step - start_time_step) - =# - - return nothing -end - -function run!(sim::ReactantSimulation; pickup=false) - - #= - start_run = time_ns() - - if we_want_to_pickup(pickup) - set!(sim, pickup) - end - =# - - sim.initialized = false - sim.running = true - sim.run_wall_time = 0.0 - - while sim.running - time_step!(sim) - end - - #= - for callback in values(sim.callbacks) - finalize!(callback, sim) - end - - # Increment the wall clock - end_run = time_ns() - sim.run_wall_time += 1e-9 * (end_run - start_run) - =# +initialize!(sim::ReactantSimulation) = update_state!(sim.model) +time_step!(sim::ReactantSimulation) = time_step!(sim.model, Δt; euler) +run!(sim::ReactantSimulation) = error("run! is not supported with ReactantState architecture.") +function first_time_step!(sim::ReactantSimulation) + initialize!(sim) + first_time_step!(sim.model, Δt) return nothing end diff --git a/ext/OceananigansReactantExt/Simulations/simulation.jl b/ext/OceananigansReactantExt/Simulations/simulation.jl index 69fabe3fbc..f90733ddc9 100644 --- a/ext/OceananigansReactantExt/Simulations/simulation.jl +++ b/ext/OceananigansReactantExt/Simulations/simulation.jl @@ -4,25 +4,15 @@ function Simulation(model::ReactantModel; Δt, verbose = true, stop_iteration = Inf) - Δt = validate_Δt(Δt, architecture(model)) + diagnostics = nothing + output_writers = nothing + callbacks = nothing - diagnostics = OrderedDict{Symbol, AbstractDiagnostic}() - output_writers = OrderedDict{Symbol, AbstractOutputWriter}() - callbacks = OrderedDict{Symbol, Callback}() - - callbacks[:stop_iteration_exceeded] = Callback(stop_iteration_exceeded) - - # Convert numbers to floating point; otherwise preserve type (eg for DateTime types) - # TODO: implement TT = timetype(model) and FT = eltype(model) - TT = eltype(model) - Δt = Δt isa Number ? TT(Δt) : Δt - - #stop_iteration = ConcreteRNumber(Float64(stop_iteration)) - stop_iteration = Float64(stop_iteration) + Δt = Float64(Δt) return Simulation(model, Δt, - stop_iteration, + Inf, nothing, # disallow stop_time Inf, diagnostics, @@ -36,27 +26,9 @@ function Simulation(model::ReactantModel; Δt, 0.0) end -function stop_iteration_exceeded(sim::ReactantSimulation) - #= - @trace if sim.model.clock.iteration >= sim.stop_iteration - #= - if sim.verbose - msg = string("Model iteration ", - iteration(sim), - " equals or exceeds stop iteration ", - Int(sim.stop_iteration), - ".") - - @info wall_time_msg(sim) - @info msg - end - =# - - sim.running = false - end - =# - - return nothing -end +iteration(sim::ReactantSimulation) = Reactant.to_number(iteration(sim.model)) +time(sim::ReactantSimulation) = Reactant.to_number(time(sim.model)) +add_callback!(::ReactantSimulation, args...) = + error("Cannot add callbacks to a Simulation with ReactantState architecture!") diff --git a/ext/OceananigansReactantExt/TimeSteppers.jl b/ext/OceananigansReactantExt/TimeSteppers.jl index 0727c71437..7ad2438bd8 100644 --- a/ext/OceananigansReactantExt/TimeSteppers.jl +++ b/ext/OceananigansReactantExt/TimeSteppers.jl @@ -12,12 +12,16 @@ using Oceananigans.TimeSteppers: update_state!, tick!, calculate_pressure_correction!, - correct_velocities_and_store_tendencies!, + correct_velocities_and_cache_previous_tendencies!, step_lagrangian_particles!, - QuasiAdamsBashforth2TimeStepper, - ab2_step! + QuasiAdamsBashforth2TimeStepper -import Oceananigans.TimeSteppers: Clock, unit_time, time_step! +using Oceananigans.Models.HydrostaticFreeSurfaceModels: + step_free_surface!, + local_ab2_step!, + compute_free_surface_tendency! + +import Oceananigans.TimeSteppers: Clock, unit_time, time_step!, ab2_step! const ReactantGrid{FT, TX, TY, TZ} = AbstractGrid{FT, TX, TY, TZ, <:ReactantState} where {FT, TX, TY, TZ} const ReactantModel{TS} = AbstractModel{TS, <:ReactantState} where TS @@ -26,12 +30,15 @@ function Clock(grid::ReactantGrid) FT = Float64 # may change in the future t = ConcreteRNumber(zero(FT)) iter = ConcreteRNumber(0) - stage = ConcreteRNumber(0) + stage = 0 #ConcreteRNumber(0) last_Δt = zero(FT) last_stage_Δt = zero(FT) return Clock(; time=t, iteration=iter, stage, last_Δt, last_stage_Δt) end +first_time_step!(model::ReactantModel, Δt) = time_step!(model, Δt) +first_time_step!(model::ReactantModel{<:QuasiAdamsBashforth2TimeStepper}, Δt) = time_step!(model, Δt, euler=true) + function time_step!(model::ReactantModel{<:QuasiAdamsBashforth2TimeStepper}, Δt; callbacks=[], euler=false) @@ -40,7 +47,6 @@ function time_step!(model::ReactantModel{<:QuasiAdamsBashforth2TimeStepper}, Δt @trace if model.clock.iteration == 0 update_state!(model, callbacks; compute_tendencies=true) end - =# # Take an euler step if: # * We detect that the time-step size has changed. @@ -51,6 +57,7 @@ function time_step!(model::ReactantModel{<:QuasiAdamsBashforth2TimeStepper}, Δt @trace if Δt != model.clock.last_Δt euler = true end + =# # If euler, then set χ = -0.5 minus_point_five = convert(eltype(model.grid), -0.5) @@ -67,7 +74,7 @@ function time_step!(model::ReactantModel{<:QuasiAdamsBashforth2TimeStepper}, Δt model.clock.last_stage_Δt = Δt # just one stage calculate_pressure_correction!(model, Δt) - @apply_regionally correct_velocities_and_store_tendencies!(model, Δt) + correct_velocities_and_cache_previous_tendencies!(model, Δt) update_state!(model, callbacks; compute_tendencies=true) step_lagrangian_particles!(model, Δt) diff --git a/ext/OceananigansReactantExt/Utils.jl b/ext/OceananigansReactantExt/Utils.jl new file mode 100644 index 0000000000..b0b455f7ef --- /dev/null +++ b/ext/OceananigansReactantExt/Utils.jl @@ -0,0 +1,18 @@ +module Utils + +using Oceananigans +using Reactant + +import Oceananigans.Utils: prettysummary, prettytime + +function prettytime(concrete_number::ConcretePJRTNumber) + number = Reactant.to_number(concrete_number) + return prettytime(number) +end + +function prettysummary(concrete_number::ConcretePJRTNumber) + number = Reactant.to_number(concrete_number) + return string("ConcretePJRTNumber(", prettysummary(number), ")") +end + +end # module diff --git a/src/Models/HydrostaticFreeSurfaceModels/show_hydrostatic_free_surface_model.jl b/src/Models/HydrostaticFreeSurfaceModels/show_hydrostatic_free_surface_model.jl index bf461a43f1..c4ee350d1d 100644 --- a/src/Models/HydrostaticFreeSurfaceModels/show_hydrostatic_free_surface_model.jl +++ b/src/Models/HydrostaticFreeSurfaceModels/show_hydrostatic_free_surface_model.jl @@ -5,7 +5,8 @@ function Base.summary(model::HydrostaticFreeSurfaceModel) A = nameof(typeof(architecture(model.grid))) G = nameof(typeof(model.grid)) return string("HydrostaticFreeSurfaceModel{$A, $G}", - "(time = ", prettytime(model.clock.time), ", iteration = ", model.clock.iteration, ")") + "(time = ", prettytime(model.clock.time), + ", iteration = ", prettysummary(model.clock.iteration), ")") end function Base.show(io::IO, model::HydrostaticFreeSurfaceModel) diff --git a/src/TimeSteppers/clock.jl b/src/TimeSteppers/clock.jl index 66887b2953..443d8e29d2 100644 --- a/src/TimeSteppers/clock.jl +++ b/src/TimeSteppers/clock.jl @@ -13,12 +13,12 @@ Keeps track of the current `time`, `last_Δt`, `iteration` number, and time-step The `stage` is updated only for multi-stage time-stepping methods. The `time::T` is either a number or a `DateTime` object. """ -mutable struct Clock{TT, DT, IT} +mutable struct Clock{TT, DT, IT, ST} time :: TT last_Δt :: DT last_stage_Δt :: DT iteration :: IT - stage :: IT + stage :: ST end """ @@ -36,8 +36,9 @@ function Clock(; time, TT = typeof(time) DT = typeof(last_Δt) IT = typeof(iteration) + ST = typeof(stage) last_stage_Δt = convert(DT, last_Δt) - return Clock{TT, DT, IT}(time, last_Δt, last_stage_Δt, iteration, stage) + return Clock{TT, DT, IT, ST}(time, last_Δt, last_stage_Δt, iteration, stage) end # TODO: when supporting DateTime, this function will have to be extended diff --git a/src/Utils/multi_region_transformation.jl b/src/Utils/multi_region_transformation.jl index f4c2fdaddb..c84d0dfb93 100644 --- a/src/Utils/multi_region_transformation.jl +++ b/src/Utils/multi_region_transformation.jl @@ -187,7 +187,7 @@ end # TODO: The macro errors when there is a return and the function has (args...) in the -# signature (example using a macro on `multi_region_buodary_conditions:L74) +# signature (example using a macro on `multi_region_boundary_conditions:L74) """ @apply_regionally expr From 26649049f29844f2b3439d138547b1a59188620c Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Sat, 8 Mar 2025 11:27:17 -0700 Subject: [PATCH 13/19] Refactor tests --- .../Simulations/simulation.jl | 2 +- test/test_reactant.jl | 23 ++++++++++++++----- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/ext/OceananigansReactantExt/Simulations/simulation.jl b/ext/OceananigansReactantExt/Simulations/simulation.jl index f90733ddc9..dfac580fc1 100644 --- a/ext/OceananigansReactantExt/Simulations/simulation.jl +++ b/ext/OceananigansReactantExt/Simulations/simulation.jl @@ -12,7 +12,7 @@ function Simulation(model::ReactantModel; Δt, return Simulation(model, Δt, - Inf, + stop_iteration, nothing, # disallow stop_time Inf, diagnostics, diff --git a/test/test_reactant.jl b/test/test_reactant.jl index babaaab4cd..05ed00dd7b 100644 --- a/test/test_reactant.jl +++ b/test/test_reactant.jl @@ -58,6 +58,19 @@ r_run!(simulation) bottom_height(x, y) = - 0.5 +function r_run!(sim, r_time_step!, r_first_time_step!) + stop_iteration = sim.stop_iteration + start_iteration = iteration(sim) + for n = start_iteration:stop_iteration + if n == 1 + r_first_time_step!(sim.model, sim.Δt) + else + r_time_step!(sim.model, sim.Δt) + end + end + return nothing +end + function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw; immersed_boundary_grid=true) r_arch = ReactantState() r_grid = GridType(r_arch; grid_kw...) @@ -113,11 +126,11 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw; r_simulation = Simulation(r_model; Δt, stop_iteration, verbose=false) @info " Compiling r_run!:" - r_run! = @compile sync=true run!(r_simulation) - # r_run! = @compile run!(r_simulation) + r_first_time_step! = @compile sync=true time_step!(r_model, Δt, euler=true) + r_time_step! = @compile sync=true time_step!(r_model, Δt) @info " Executing r_run!:" - r_run!(r_simulation) + r_run!(r_simulation, r_time_step!, r_first_time_step!) @info " After running 3 time steps, the reactant model:" @test iteration(r_simulation) == stop_iteration @@ -142,7 +155,7 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw; # Running a few more time-steps works too: r_simulation.stop_iteration += 2 - r_run!(r_simulation) + r_run!(r_simulation, r_time_step!, r_first_time_step!) @test iteration(r_simulation) == 5 @test time(r_simulation) == 5Δt @@ -160,7 +173,6 @@ end @inbounds f[i, j, k] += 1 end -#= @testset "Reactanigans unit tests" begin @info "Performing Reactanigans unit tests..." arch = ReactantState() @@ -250,7 +262,6 @@ end end end end -=# @testset "Reactant Super Simple Simulation Tests" begin @info "Performing Reactanigans super simple simulation tests..." From 9e8575d62774c06a6bb3d15091154d32dd681504 Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Sat, 8 Mar 2025 11:28:22 -0700 Subject: [PATCH 14/19] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e1e5e15ce6..bc08f52977 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Oceananigans" uuid = "9e8cae18-63c1-5223-a75c-80ca9d6e9a09" authors = ["Climate Modeling Alliance and contributors"] -version = "0.95.20" +version = "0.95.21" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 4bfbc2c2e6468c6096576c1bef9adeb09aeee5bf Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Sat, 8 Mar 2025 17:19:58 -0700 Subject: [PATCH 15/19] Fix Clock plus rm weird printing --- ext/OceananigansReactantExt/Simulations/simulation.jl | 2 +- src/TimeSteppers/clock.jl | 3 ++- src/TimeSteppers/quasi_adams_bashforth_2.jl | 2 -- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/ext/OceananigansReactantExt/Simulations/simulation.jl b/ext/OceananigansReactantExt/Simulations/simulation.jl index dfac580fc1..6770911411 100644 --- a/ext/OceananigansReactantExt/Simulations/simulation.jl +++ b/ext/OceananigansReactantExt/Simulations/simulation.jl @@ -12,7 +12,7 @@ function Simulation(model::ReactantModel; Δt, return Simulation(model, Δt, - stop_iteration, + Float64(stop_iteration), nothing, # disallow stop_time Inf, diagnostics, diff --git a/src/TimeSteppers/clock.jl b/src/TimeSteppers/clock.jl index 443d8e29d2..834f9e467a 100644 --- a/src/TimeSteppers/clock.jl +++ b/src/TimeSteppers/clock.jl @@ -54,8 +54,9 @@ function Clock{TT}(; time, last_Δt = convert(DT, last_Δt) last_stage_Δt = convert(DT, last_stage_Δt) IT = typeof(iteration) + ST = typeof(stage) - return Clock{TT, DT, IT}(time, last_Δt, last_stage_Δt, iteration, stage) + return Clock{TT, DT, IT, ST}(time, last_Δt, last_stage_Δt, iteration, stage) end # helpful default diff --git a/src/TimeSteppers/quasi_adams_bashforth_2.jl b/src/TimeSteppers/quasi_adams_bashforth_2.jl index d16afff096..d02a018fd1 100644 --- a/src/TimeSteppers/quasi_adams_bashforth_2.jl +++ b/src/TimeSteppers/quasi_adams_bashforth_2.jl @@ -74,8 +74,6 @@ The steps of the Quasi-Adams-Bashforth second-order (AB2) algorithm are: function time_step!(model::AbstractModel{<:QuasiAdamsBashforth2TimeStepper}, Δt; callbacks=[], euler=false) - @info "Not using Reactantified time_step!" - Δt == 0 && @warn "Δt == 0 may cause model blowup!" # Be paranoid and update state at iteration 0 From 45195f3b1bd87f6cfcdbd03999df9740359262d9 Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Sat, 8 Mar 2025 23:44:11 -0700 Subject: [PATCH 16/19] Fix tests --- src/Simulations/run.jl | 1 - test/test_reactant.jl | 43 ++++++++++++++++++++++++++++-------------- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/src/Simulations/run.jl b/src/Simulations/run.jl index faa6885b22..ac9ddba02e 100644 --- a/src/Simulations/run.jl +++ b/src/Simulations/run.jl @@ -208,7 +208,6 @@ function initialize!(sim::Simulation) model = sim.model clock = model.clock - update_state!(model) # Output and diagnostics initialization diff --git a/test/test_reactant.jl b/test/test_reactant.jl index 05ed00dd7b..5e7cf1dfe9 100644 --- a/test/test_reactant.jl +++ b/test/test_reactant.jl @@ -58,9 +58,16 @@ r_run!(simulation) bottom_height(x, y) = - 0.5 +function first_time_step!(model, Δt) + Oceananigans.initialize!(model) + Oceananigans.TimeSteppers.update_state!(model, compute_tendencies=true) + Oceananigans.TimeSteppers.time_step!(model, Δt, euler=true) + return nothing +end + function r_run!(sim, r_time_step!, r_first_time_step!) stop_iteration = sim.stop_iteration - start_iteration = iteration(sim) + start_iteration = iteration(sim) + 1 for n = start_iteration:stop_iteration if n == 1 r_first_time_step!(sim.model, sim.Δt) @@ -91,8 +98,8 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw; ui = randn(size(model.velocities.u)...) vi = randn(size(model.velocities.v)...) - set!(model, u=ui, v=ui) - set!(r_model, u=ui, v=ui) + set!(model, u=ui, v=vi) + set!(r_model, u=ui, v=vi) u, v, w = model.velocities ru, rv, rw = r_model.velocities @@ -125,12 +132,14 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw; # Reactant time now: r_simulation = Simulation(r_model; Δt, stop_iteration, verbose=false) - @info " Compiling r_run!:" - r_first_time_step! = @compile sync=true time_step!(r_model, Δt, euler=true) - r_time_step! = @compile sync=true time_step!(r_model, Δt) + @time " Compiling r_run!:" begin + r_first_time_step! = @compile sync=true first_time_step!(r_model, Δt) + r_time_step! = @compile sync=true time_step!(r_model, Δt) + end - @info " Executing r_run!:" - r_run!(r_simulation, r_time_step!, r_first_time_step!) + @time " Executing r_run!:" begin + r_run!(r_simulation, r_time_step!, r_first_time_step!) + end @info " After running 3 time steps, the reactant model:" @test iteration(r_simulation) == stop_iteration @@ -180,7 +189,6 @@ end c = CenterField(grid) @test parent(c) isa Reactant.ConcretePJRTArray - @info " Testing field set! with a number..." set!(c, 1) @test all(c .≈ 1) @@ -284,25 +292,31 @@ end @info "Testing RectilinearGrid + HydrostaticFreeSurfaceModel Reactant correctness" hydrostatic_model_kw = (; free_surface=ExplicitFreeSurface(gravitational_acceleration=1)) test_reactant_model_correctness(RectilinearGrid, HydrostaticFreeSurfaceModel, rectilinear_kw, hydrostatic_model_kw) - test_reactant_model_correctness(RectilinearGrid, HydrostaticFreeSurfaceModel, rectilinear_kw, hydrostatic_model_kw, immersed_boundary_grid=true) + + @info "Testing immersed RectilinearGrid + HydrostaticFreeSurfaceModel Reactant correctness" + test_reactant_model_correctness(RectilinearGrid, HydrostaticFreeSurfaceModel, rectilinear_kw, hydrostatic_model_kw, + immersed_boundary_grid=true) @info "Testing LatitudeLongitudeGrid + HydrostaticFreeSurfaceModel Reactant correctness" hydrostatic_model_kw = (; momentum_advection = WENO()) test_reactant_model_correctness(LatitudeLongitudeGrid, HydrostaticFreeSurfaceModel, lat_lon_kw, hydrostatic_model_kw) - test_reactant_model_correctness(LatitudeLongitudeGrid, HydrostaticFreeSurfaceModel, lat_lon_kw, hydrostatic_model_kw, immersed_boundary_grid=true) - #= + @info "Testing immersed LatitudeLongitudeGrid + HydrostaticFreeSurfaceModel Reactant correctness" + test_reactant_model_correctness(LatitudeLongitudeGrid, HydrostaticFreeSurfaceModel, lat_lon_kw, hydrostatic_model_kw, + immersed_boundary_grid=true) + # This test takes too long @info "Testing LatitudeLongitudeGrid + SplitExplicitFreeSurface + HydrostaticFreeSurfaceModel Reactant correctness" hydrostatic_model_kw = (; momentum_advection=WENOVectorInvariant(), free_surface=SplitExplicitFreeSurface(substeps=4)) test_reactant_model_correctness(LatitudeLongitudeGrid, HydrostaticFreeSurfaceModel, lat_lon_kw, hydrostatic_model_kw) - simulation = test_reactant_model_correctness(LatitudeLongitudeGrid, HydrostaticFreeSurfaceModel, lat_lon_kw, hydrostatic_model_kw, immersed_boundary_grid=true) + simulation = test_reactant_model_correctness(LatitudeLongitudeGrid, HydrostaticFreeSurfaceModel, lat_lon_kw, + hydrostatic_model_kw, immersed_boundary_grid=true) η = simulation.model.free_surface.η η_grid = η.grid @test isnothing(η_grid.interior_active_cells) @test isnothing(η_grid.active_z_columns) - =# + #= @info "Testing LatitudeLongitudeGrid + 'complicated HydrostaticFreeSurfaceModel' Reactant correctness" equation_of_state = TEOS10EquationOfState() hydrostatic_model_kw = (momentum_advection = WENOVectorInvariant(), @@ -312,5 +326,6 @@ end closure = CATKEVerticalDiffusivity()) test_reactant_model_correctness(LatitudeLongitudeGrid, HydrostaticFreeSurfaceModel, lat_lon_kw, hydrostatic_model_kw) + =# end From eb52bf350d28281b9acf696d19c4a594cebdbe8a Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Sun, 9 Mar 2025 00:04:22 -0700 Subject: [PATCH 17/19] Use first_time_step in tests --- .../Simulations/Simulations.jl | 2 ++ ext/OceananigansReactantExt/Simulations/run.jl | 10 ++++++++-- ext/OceananigansReactantExt/TimeSteppers.jl | 16 ++++++++++++++-- src/Simulations/run.jl | 13 +++++-------- test/test_reactant.jl | 9 +++++++-- 5 files changed, 36 insertions(+), 14 deletions(-) diff --git a/ext/OceananigansReactantExt/Simulations/Simulations.jl b/ext/OceananigansReactantExt/Simulations/Simulations.jl index adfdba9ac7..adfbb1ee23 100644 --- a/ext/OceananigansReactantExt/Simulations/Simulations.jl +++ b/ext/OceananigansReactantExt/Simulations/Simulations.jl @@ -1,5 +1,7 @@ module Simulations +export first_time_step!, time_step_for! + using Reactant using Oceananigans diff --git a/ext/OceananigansReactantExt/Simulations/run.jl b/ext/OceananigansReactantExt/Simulations/run.jl index 125fc733ab..ff9a0885af 100644 --- a/ext/OceananigansReactantExt/Simulations/run.jl +++ b/ext/OceananigansReactantExt/Simulations/run.jl @@ -1,13 +1,19 @@ import ..TimeSteppers: first_time_step! """ Step `sim`ulation forward by one time step. """ -initialize!(sim::ReactantSimulation) = update_state!(sim.model) time_step!(sim::ReactantSimulation) = time_step!(sim.model, Δt; euler) run!(sim::ReactantSimulation) = error("run! is not supported with ReactantState architecture.") function first_time_step!(sim::ReactantSimulation) initialize!(sim) - first_time_step!(sim.model, Δt) + first_time_step!(sim.model, sim.Δt) + return nothing +end + +function time_step_for!(sim::ReactantSimulation, Nsteps) + @trace for _ = 1:Nsteps + time_step!(sim) + end return nothing end diff --git a/ext/OceananigansReactantExt/TimeSteppers.jl b/ext/OceananigansReactantExt/TimeSteppers.jl index 7ad2438bd8..26f05d650d 100644 --- a/ext/OceananigansReactantExt/TimeSteppers.jl +++ b/ext/OceananigansReactantExt/TimeSteppers.jl @@ -22,6 +22,7 @@ using Oceananigans.Models.HydrostaticFreeSurfaceModels: compute_free_surface_tendency! import Oceananigans.TimeSteppers: Clock, unit_time, time_step!, ab2_step! +import Oceananigans: initialize! const ReactantGrid{FT, TX, TY, TZ} = AbstractGrid{FT, TX, TY, TZ, <:ReactantState} where {FT, TX, TY, TZ} const ReactantModel{TS} = AbstractModel{TS, <:ReactantState} where TS @@ -36,8 +37,19 @@ function Clock(grid::ReactantGrid) return Clock(; time=t, iteration=iter, stage, last_Δt, last_stage_Δt) end -first_time_step!(model::ReactantModel, Δt) = time_step!(model, Δt) -first_time_step!(model::ReactantModel{<:QuasiAdamsBashforth2TimeStepper}, Δt) = time_step!(model, Δt, euler=true) +function first_time_step!(model::ReactantModel, Δt) + initialize!(model) + update_state!(model) + time_step!(model, Δt) + return nothing +end + +function first_time_step!(model::ReactantModel{<:QuasiAdamsBashforth2TimeStepper}, Δt) + initialize!(model) + update_state!(model) + time_step!(model, Δt, euler=true) + return nothing +end function time_step!(model::ReactantModel{<:QuasiAdamsBashforth2TimeStepper}, Δt; callbacks=[], euler=false) diff --git a/src/Simulations/run.jl b/src/Simulations/run.jl index ac9ddba02e..f2ecb454a8 100644 --- a/src/Simulations/run.jl +++ b/src/Simulations/run.jl @@ -135,10 +135,7 @@ function time_step!(sim::Simulation) end initial_time_step = !(sim.initialized) - if initial_time_step # execute initialization step - initialize!(sim) - initialize!(sim.model) - end + initial_time_step && initialize!(sim) if initial_time_step && sim.verbose @info "Executing initial time step..." @@ -207,7 +204,7 @@ function initialize!(sim::Simulation) end model = sim.model - clock = model.clock + initialize!(model) update_state!(model) # Output and diagnostics initialization @@ -223,8 +220,8 @@ function initialize!(sim::Simulation) end # Reset! the model time-stepper, evaluate all diagnostics, and write all output at first iteration - if clock.iteration == 0 - reset!(timestepper(sim.model)) + if model.clock.iteration == 0 + reset!(timestepper(model)) # Initialize schedules and run diagnostics, callbacks, and output writers for diag in values(sim.diagnostics) @@ -236,7 +233,7 @@ function initialize!(sim::Simulation) end for writer in values(sim.output_writers) - writer.schedule(sim.model) + writer.schedule(model) write_output!(writer, model) end end diff --git a/test/test_reactant.jl b/test/test_reactant.jl index 5e7cf1dfe9..6a579d47bc 100644 --- a/test/test_reactant.jl +++ b/test/test_reactant.jl @@ -132,13 +132,17 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw; # Reactant time now: r_simulation = Simulation(r_model; Δt, stop_iteration, verbose=false) + Nsteps = ConcretePJRTNumber(3) @time " Compiling r_run!:" begin - r_first_time_step! = @compile sync=true first_time_step!(r_model, Δt) - r_time_step! = @compile sync=true time_step!(r_model, Δt) + r_first_time_step! = @compile sync=true OceananigansReactantExt.first_time_step!(r_model, Δt) + r_time_step! = @compile sync=true Oceananigans.TimeSteppers.time_step!(r_model, Δt) + #r_time_step_for! = @compile sync=true OceananigansReactantExt.time_step_for!(r_simulation, Nsteps) end @time " Executing r_run!:" begin r_run!(r_simulation, r_time_step!, r_first_time_step!) + #r_first_time_step!(r_simulation) + #r_time_step_for!(r_simulation, 2) end @info " After running 3 time steps, the reactant model:" @@ -165,6 +169,7 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw; # Running a few more time-steps works too: r_simulation.stop_iteration += 2 r_run!(r_simulation, r_time_step!, r_first_time_step!) + #r_time_step_for!(r_simulation, 2) @test iteration(r_simulation) == 5 @test time(r_simulation) == 5Δt From 25c232f8a455b8f19333ffd3f844dea289d21605 Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Sun, 9 Mar 2025 08:25:09 -0600 Subject: [PATCH 18/19] Transfer first_time_step to Oceananigans --- ext/OceananigansReactantExt/TimeSteppers.jl | 14 -------------- src/TimeSteppers/TimeSteppers.jl | 16 +++++++++++++++- test/test_reactant.jl | 9 +-------- 3 files changed, 16 insertions(+), 23 deletions(-) diff --git a/ext/OceananigansReactantExt/TimeSteppers.jl b/ext/OceananigansReactantExt/TimeSteppers.jl index 26f05d650d..1363065191 100644 --- a/ext/OceananigansReactantExt/TimeSteppers.jl +++ b/ext/OceananigansReactantExt/TimeSteppers.jl @@ -37,20 +37,6 @@ function Clock(grid::ReactantGrid) return Clock(; time=t, iteration=iter, stage, last_Δt, last_stage_Δt) end -function first_time_step!(model::ReactantModel, Δt) - initialize!(model) - update_state!(model) - time_step!(model, Δt) - return nothing -end - -function first_time_step!(model::ReactantModel{<:QuasiAdamsBashforth2TimeStepper}, Δt) - initialize!(model) - update_state!(model) - time_step!(model, Δt, euler=true) - return nothing -end - function time_step!(model::ReactantModel{<:QuasiAdamsBashforth2TimeStepper}, Δt; callbacks=[], euler=false) diff --git a/src/TimeSteppers/TimeSteppers.jl b/src/TimeSteppers/TimeSteppers.jl index 6513186991..ae29bf91f2 100644 --- a/src/TimeSteppers/TimeSteppers.jl +++ b/src/TimeSteppers/TimeSteppers.jl @@ -9,7 +9,7 @@ export using CUDA using KernelAbstractions -using Oceananigans: AbstractModel, prognostic_fields +using Oceananigans: AbstractModel, initialize!, prognostic_fields using Oceananigans.Architectures: device using Oceananigans.Utils: work_layout @@ -66,4 +66,18 @@ TimeStepper(::Val{:RungeKutta3}, args...; kwargs...) = TimeStepper(::Val{:SplitRungeKutta3}, args...; kwargs...) = SplitRungeKutta3TimeStepper(args...; kwargs...) +function first_time_step!(model::AbstractModel, Δt) + initialize!(model) + update_state!(model) + time_step!(model, Δt) + return nothing +end + +function first_time_step!(model::AbstractModel{<:QuasiAdamsBashforth2TimeStepper}, Δt) + initialize!(model) + update_state!(model) + time_step!(model, Δt, euler=true) + return nothing +end + end # module diff --git a/test/test_reactant.jl b/test/test_reactant.jl index 6a579d47bc..a668350cd7 100644 --- a/test/test_reactant.jl +++ b/test/test_reactant.jl @@ -58,13 +58,6 @@ r_run!(simulation) bottom_height(x, y) = - 0.5 -function first_time_step!(model, Δt) - Oceananigans.initialize!(model) - Oceananigans.TimeSteppers.update_state!(model, compute_tendencies=true) - Oceananigans.TimeSteppers.time_step!(model, Δt, euler=true) - return nothing -end - function r_run!(sim, r_time_step!, r_first_time_step!) stop_iteration = sim.stop_iteration start_iteration = iteration(sim) + 1 @@ -134,7 +127,7 @@ function test_reactant_model_correctness(GridType, ModelType, grid_kw, model_kw; Nsteps = ConcretePJRTNumber(3) @time " Compiling r_run!:" begin - r_first_time_step! = @compile sync=true OceananigansReactantExt.first_time_step!(r_model, Δt) + r_first_time_step! = @compile sync=true Oceananigans.TimeSteppers.first_time_step!(r_model, Δt) r_time_step! = @compile sync=true Oceananigans.TimeSteppers.time_step!(r_model, Δt) #r_time_step_for! = @compile sync=true OceananigansReactantExt.time_step_for!(r_simulation, Nsteps) end From d7da2e16797145373c745468bdb7c852192f8cad Mon Sep 17 00:00:00 2001 From: Gregory Wagner Date: Sun, 9 Mar 2025 08:50:12 -0600 Subject: [PATCH 19/19] Get enzyme tests to pass --- src/TimeSteppers/clock.jl | 19 +++++++++++++------ test/test_enzyme.jl | 9 +++------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/TimeSteppers/clock.jl b/src/TimeSteppers/clock.jl index 834f9e467a..d7f67b696f 100644 --- a/src/TimeSteppers/clock.jl +++ b/src/TimeSteppers/clock.jl @@ -13,12 +13,21 @@ Keeps track of the current `time`, `last_Δt`, `iteration` number, and time-step The `stage` is updated only for multi-stage time-stepping methods. The `time::T` is either a number or a `DateTime` object. """ -mutable struct Clock{TT, DT, IT, ST} +mutable struct Clock{TT, DT, IT} time :: TT last_Δt :: DT last_stage_Δt :: DT iteration :: IT - stage :: ST + stage :: Int +end + +function reset!(clock::Clock{TT, DT, IT}) where {TT, DT, IT} + clock.time = zero(TT) + clock.iteration = zero(IT) + clock.stage = 0 + clock.last_Δt = Inf + clock.last_stage_Δt = Inf + return nothing end """ @@ -36,9 +45,8 @@ function Clock(; time, TT = typeof(time) DT = typeof(last_Δt) IT = typeof(iteration) - ST = typeof(stage) last_stage_Δt = convert(DT, last_Δt) - return Clock{TT, DT, IT, ST}(time, last_Δt, last_stage_Δt, iteration, stage) + return Clock{TT, DT, IT}(time, last_Δt, last_stage_Δt, iteration, stage) end # TODO: when supporting DateTime, this function will have to be extended @@ -54,9 +62,8 @@ function Clock{TT}(; time, last_Δt = convert(DT, last_Δt) last_stage_Δt = convert(DT, last_stage_Δt) IT = typeof(iteration) - ST = typeof(stage) - return Clock{TT, DT, IT, ST}(time, last_Δt, last_stage_Δt, iteration, stage) + return Clock{TT, DT, IT}(time, last_Δt, last_stage_Δt, iteration, stage) end # helpful default diff --git a/test/test_enzyme.jl b/test/test_enzyme.jl index 732caf746a..5f2877f68a 100644 --- a/test/test_enzyme.jl +++ b/test/test_enzyme.jl @@ -1,6 +1,7 @@ include("dependencies_for_runtests.jl") using Enzyme +using Oceananigans.TimeSteppers: reset! # Required presently Enzyme.API.looseTypeAnalysis!(true) @@ -36,6 +37,7 @@ function set_initial_condition!(model, amplitude) end function stable_diffusion!(model, amplitude, diffusivity) + reset!(model.clock) set_diffusivity!(model, diffusivity) set_initial_condition!(model, amplitude) @@ -45,9 +47,6 @@ function stable_diffusion!(model, amplitude, diffusivity) Δz = 1 / Nz Δt = 1e-1 * Δz^2 / κ_max - model.clock.time = 0 - model.clock.iteration = 0 - for _ = 1:10 time_step!(model, Δt; euler=true) end @@ -291,9 +290,7 @@ end function viscous_hydrostatic_turbulence(ν, model, u_init, v_init, Δt, u_truth, v_truth) # Initialize the model - model.clock.iteration = 0 - model.clock.time = 0 - model.clock.last_Δt = Inf + reset!(model.clock) set_viscosity!(model, ν) set!(model, u=u_init, v=v_init) fill!(model.free_surface.η, 0)