Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set a default Clock depending on the grid #4096

Merged
merged 29 commits into from
Mar 9, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
5b5e893
Set a default Clock depending on the grid
glwagner Feb 12, 2025
c12d044
Test default clock
glwagner Feb 13, 2025
fb38e66
Fix up the implementation of the extension and clean up tests
glwagner Feb 16, 2025
ba16173
Some major extensions
glwagner Feb 16, 2025
534cb96
Merge branch 'main' into glw/arch-dep-clock
glwagner Feb 16, 2025
ff4480d
Extend initialize!
glwagner Feb 17, 2025
6cf14a2
Working through Reactant errors
glwagner Feb 18, 2025
2be1f74
Merge branch 'main' into glw/arch-dep-clock
glwagner Feb 21, 2025
037d15f
Trying a few things plus extend stop_iteration_exceeded
glwagner Feb 22, 2025
ab070dd
Merge branch 'main' into glw/arch-dep-clock
glwagner Feb 22, 2025
811dca5
Updates
glwagner Feb 24, 2025
912db0e
Add test code
glwagner Feb 26, 2025
70effaa
Finally get test_reactant to copmile for a long time
glwagner Feb 27, 2025
2e1fa09
Merge remote-tracking branch 'origin/main' into glw/arch-dep-clock
glwagner Feb 27, 2025
50eb3de
Bugfix Reactant Simulation constructor
glwagner Feb 27, 2025
002c6a5
Merge branch 'main' into glw/arch-dep-clock
glwagner Mar 4, 2025
7173c28
Merge remote-tracking branch 'origin/main' into glw/arch-dep-clock
glwagner Mar 7, 2025
146dd7b
Merge remote-tracking branch 'origin/main' into glw/arch-dep-clock
glwagner Mar 7, 2025
f1d1fbf
Refactor simulation with Reactant, discontinue support for run!
glwagner Mar 8, 2025
2664904
Refactor tests
glwagner Mar 8, 2025
9e8575d
Bump version
glwagner Mar 8, 2025
3bcd95d
Merge branch 'main' into glw/arch-dep-clock
glwagner Mar 8, 2025
4bfbc2c
Fix Clock plus rm weird printing
glwagner Mar 9, 2025
edccb0f
Merge branch 'main' into glw/arch-dep-clock
glwagner Mar 9, 2025
45195f3
Fix tests
glwagner Mar 9, 2025
b8bf002
Merge branch 'glw/arch-dep-clock' of https://github.com/CliMA/Oceanan…
glwagner Mar 9, 2025
eb52bf3
Use first_time_step in tests
glwagner Mar 9, 2025
25c232f
Transfer first_time_step to Oceananigans
glwagner Mar 9, 2025
d7da2e1
Get enzyme tests to pass
glwagner Mar 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions ext/OceananigansReactantExt/Architectures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ module Architectures
using Reactant
using Oceananigans

import Oceananigans.Architectures: device, architecture, array_type, on_architecture, unified_array, ReactantState, device_copy_to!
import Oceananigans.Architectures: device, architecture, array_type, on_architecture
import Oceananigans.Architectures: unified_array, ReactantState, device_copy_to!

const ReactantKernelAbstractionsExt = Base.get_extension(
Reactant, :ReactantKernelAbstractionsExt
)

const ReactantBackend = ReactantKernelAbstractionsExt.ReactantBackend

device(::ReactantState) = ReactantBackend()

architecture(::Reactant.AnyConcreteRArray) = ReactantState
Expand All @@ -24,6 +27,7 @@ on_architecture(::ReactantState, a::SubArray{<:Any, <:Any, <:Array}) = ConcreteR

unified_array(::ReactantState, a) = a

@inline device_copy_to!(dst::Reactant.AnyConcreteRArray, src::Reactant.AnyConcreteRArray; kw...) = Base.copyto!(dst, src)
@inline device_copy_to!(dst::Reactant.AnyConcreteRArray, src::Reactant.AnyConcreteRArray; kw...) =
Base.copyto!(dst, src)

end # module
1 change: 1 addition & 0 deletions ext/OceananigansReactantExt/OceananigansReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Reactant
include("Architectures.jl")
using .Architectures


# These are additional modules that may need to be Reactantified in the future:
#
# include("Utils.jl")
Expand Down
24 changes: 24 additions & 0 deletions ext/OceananigansReactantExt/TimeSteppers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module TimeSteppers

using Reactant
using Oceananigans

import Oceananigans.Grids: AbstractGrid
import Oceananigans.TimeSteppers: Clock

using OceananigansReactantExt: ReactantState

const ReactantGrid{FT, TX, TY, TZ} = AbstractGrid{FT, TX, TY, TZ, <:ReactantState} where {FT, TX, TY, TZ}

function Clock(grid::ReactantGrid)
FT = Float64 # may change in the future
t = ConcreteRNumber(zero(FT))
iter = ConcreteRNumber(0)
stage = ConcreteRNumber(0)
last_Δt = ConcreteRNumber(zero(FT))
last_stage_Δt = ConcreteRNumber(zero(FT))
return Clock(; time=t, iteration=iter, stage, last_Δt, last_stage_Δt)
end

end # module

Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ Keyword arguments
- `vertical_coordinate`: Rulesets that define the time-evolution of the grid (ZStar/ZCoordinate). Default: `ZCoordinate()`.
"""
function HydrostaticFreeSurfaceModel(; grid,
clock = Clock{Float64}(time = 0),
clock = Clock(grid),
momentum_advection = VectorInvariant(),
tracer_advection = Centered(),
buoyancy = nothing,
Expand Down
2 changes: 1 addition & 1 deletion src/Models/NonhydrostaticModels/nonhydrostatic_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ Keyword arguments
- `auxiliary_fields`: `NamedTuple` of auxiliary fields. Default: `nothing`
"""
function NonhydrostaticModel(; grid,
clock = Clock{Float64}(time = 0),
clock = Clock(grid),
advection = Centered(),
buoyancy = nothing,
coriolis = nothing,
Expand Down
2 changes: 1 addition & 1 deletion src/Models/ShallowWaterModels/shallow_water_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ Keyword arguments
function ShallowWaterModel(;
grid,
gravitational_acceleration,
clock = Clock{eltype(grid)}(time=0),
clock = Clock(grid),
momentum_advection = UpwindBiased(order=5),
tracer_advection = WENO(),
mass_advection = WENO(),
Expand Down
4 changes: 4 additions & 0 deletions src/TimeSteppers/clock.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Adapt
using Dates: AbstractTime, DateTime, Nanosecond, Millisecond
using Oceananigans.Utils: prettytime
using Oceananigans.Grids: AbstractGrid

import Base: show
import Oceananigans.Units: Time
Expand Down Expand Up @@ -54,6 +55,9 @@ function Clock{TT}(; time,
return Clock{TT, DT}(time, last_Δt, last_stage_Δt, iteration, stage)
end

# helpful default
Clock(grid::AbstractGrid) = Clock{Float64}(time=0)

function Base.summary(clock::Clock)
TT = typeof(clock.time)
DT = typeof(clock.last_Δt)
Expand Down