diff --git a/Project.toml b/Project.toml
index f670d14e..e8ec6eea 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,6 +17,12 @@ NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 
+[weakdeps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+
+[extensions]
+CudaExt = "CUDA"
+
 [compat]
 ClimaComms = "0.4, 0.5"
 Colors = "0.12"
diff --git a/ext/CudaExt.jl b/ext/CudaExt.jl
new file mode 100644
index 00000000..af2c3d4e
--- /dev/null
+++ b/ext/CudaExt.jl
@@ -0,0 +1,84 @@
+module CudaExt
+
+import CUDA
+import ClimaComms: SingletonCommsContext, CUDADevice
+import ClimaTimeSteppers: compute_fj!
+
+"""
+    compute_fj!(f, j, U, f!, j!, ::SingletonCommsContext{CUDADevice})
+
+Evaluate `f!(f, U)` and `j!(j, U)` on a single-process CUDA context. Uses
+`compute_fj_spawn!` when Julia runs with more than one thread and
+`compute_fj_streams!` otherwise.
+"""
+@inline function compute_fj!(f, j, U, f!, j!, ::SingletonCommsContext{CUDADevice})
+    # TODO: we should benchmark these two options to
+    # see if one is preferable over the other
+    if Base.Threads.nthreads() > 1
+        compute_fj_spawn!(f, j, U, f!, j!)
+    else
+        compute_fj_streams!(f, j, U, f!, j!)
+    end
+end
+
+"""
+    compute_fj_streams!(f, j, U, f!, j!)
+
+Issue `f!(f, U)` and `j!(j, U)` on two newly created CUDA streams. Each
+stream first waits on an event recorded on the current stream, and the
+current stream then waits on an event recorded on each new stream.
+"""
+@inline function compute_fj_streams!(f, j, U, f!, j!)
+    event = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
+    CUDA.record(event, CUDA.stream()) # record event on main stream
+
+    # Create the completion events up front instead of assigning them inside
+    # the `do` closures, so the closures capture no reassigned variables.
+    event1 = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
+    event2 = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
+
+    stream1 = CUDA.CuStream() # make a stream
+    CUDA.stream!(stream1) do # work to be done by stream1
+        CUDA.wait(event, stream1) # make stream1 wait on event (host continues)
+        f!(f, U)
+    end
+    CUDA.record(event1, stream1) # record event1 on stream1
+
+    stream2 = CUDA.CuStream() # make a stream
+    CUDA.stream!(stream2) do # work to be done by stream2
+        CUDA.wait(event, stream2) # make stream2 wait on event (host continues)
+        j!(j, U)
+    end
+    CUDA.record(event2, stream2) # record event2 on stream2
+
+    CUDA.wait(event1, CUDA.stream()) # make main stream wait on event1
+    CUDA.wait(event2, CUDA.stream()) # make main stream wait on event2
+end
+
+"""
+    compute_fj_spawn!(f, j, U, f!, j!)
+
+Run `f!(f, U)` and `j!(j, U)` in two spawned tasks and block until both
+tasks have finished. `CUDA.synchronize()` is called once before spawning
+and once inside each task after its call.
+"""
+@inline function compute_fj_spawn!(f, j, U, f!, j!)
+
+    CUDA.synchronize()
+    # `Base.@sync` waits for the spawned tasks; `CUDA.@sync` only runs the
+    # block and synchronizes the GPU, so it would return before they finish.
+    Base.@sync begin
+        Base.Threads.@spawn begin
+            f!(f, U)
+            CUDA.synchronize()
+            nothing
+        end
+        Base.Threads.@spawn begin
+            j!(j, U)
+            CUDA.synchronize()
+            nothing
+        end
+    end
+end
+
+end
diff --git a/src/solvers/compute_fj.jl b/src/solvers/compute_fj.jl
new file mode 100644
index 00000000..de2fe5ca
--- /dev/null
+++ b/src/solvers/compute_fj.jl
@@ -0,0 +1,30 @@
+"""
+    compute_fj!(f, j, U, f!, j!, ::Union{Nothing, ClimaComms.AbstractCommsContext})
+
+Evaluate `f!(f, U)` and then `j!(j, U)`, one after the other. This method
+applies when the last argument is `nothing` or a comms context for which no
+more specific method is defined.
+"""
+@inline function compute_fj!(f, j, U, f!, j!, ::Union{Nothing, ClimaComms.AbstractCommsContext})
+    f!(f, U)
+    j!(j, U)
+end
+
+"""
+    compute_fj!(f, j, U, f!, j!, ::ClimaComms.SingletonCommsContext{ClimaComms.CPUMultiThreaded})
+
+Evaluate `f!(f, U)` and `j!(j, U)` in two spawned tasks and wait for both
+tasks to finish before returning.
+"""
+@inline function compute_fj!(f, j, U, f!, j!, ::ClimaComms.SingletonCommsContext{ClimaComms.CPUMultiThreaded})
+    Base.@sync begin
+        Base.Threads.@spawn begin
+            f!(f, U)
+            nothing
+        end
+        Base.Threads.@spawn begin
+            j!(j, U)
+            nothing
+        end
+    end
+end