Skip to content

Commit

Permalink
Add support for async f and j
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Mar 25, 2024
1 parent f2e2b71 commit b29df31
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 0 deletions.
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
CudaExt = "CUDA"

[compat]
ClimaComms = "0.4, 0.5"
Colors = "0.12"
Expand Down
60 changes: 60 additions & 0 deletions ext/CudaExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
module CudaExt

import CUDA
import ClimaComms: SingletonCommsContext, CUDADevice
import ClimaTimeSteppers: compute_fj!

"""
    compute_fj!(f, j, U, f!, j!, ::SingletonCommsContext{CUDADevice})

Evaluate `f!(f, U)` and `j!(j, U)` for a single-process CUDA context.

When Julia was started with more than one thread the work is handed to
`compute_fj_spawn!`; otherwise it is handed to `compute_fj_streams!`.
The value of whichever helper ran is returned.
"""
@inline function compute_fj!(f, j, U, f!, j!, ::SingletonCommsContext{CUDADevice})
    # TODO: benchmark the task-based and the stream-based variants to
    # see whether one is preferable over the other
    multithreaded = Base.Threads.nthreads() > 1
    return multithreaded ? compute_fj_spawn!(f, j, U, f!, j!) :
           compute_fj_streams!(f, j, U, f!, j!)
end

"""
    compute_fj_streams!(f, j, U, f!, j!)

Enqueue `f!(f, U)` and `j!(j, U)` on two freshly created CUDA streams so the
two evaluations are not serialized behind each other on the caller's stream.

Ordering is expressed purely with events:

1. `ready` is recorded on the caller's current stream; both side streams wait
   on it before running their work.
2. `f_done` / `j_done` are recorded on the side streams after the respective
   call has been enqueued; the caller's stream then waits on both.

The waits are `CUDA.wait(event, stream)` calls, i.e. stream-side waits; this
function does not call `CUDA.synchronize`.

The events are created *before* the `do` blocks rather than assigned from
inside them: a local that is assigned inside a closure and read outside of it
is boxed by Julia, which makes the surrounding code type-unstable.
"""
@inline function compute_fj_streams!(f, j, U, f!, j!)
    main_stream = CUDA.stream()

    # Marks the point on the caller's stream that both side streams must wait for.
    ready = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
    CUDA.record(ready, main_stream)

    # --- f! on its own stream ---
    f_stream = CUDA.CuStream()
    f_done = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
    CUDA.stream!(f_stream) do
        CUDA.wait(ready, f_stream) # make f_stream wait on `ready`
        f!(f, U)
    end
    CUDA.record(f_done, f_stream) # completion marker for f!

    # --- j! on its own stream ---
    j_stream = CUDA.CuStream()
    j_done = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
    CUDA.stream!(j_stream) do
        CUDA.wait(ready, j_stream) # make j_stream wait on `ready`
        j!(j, U)
    end
    CUDA.record(j_done, j_stream) # completion marker for j!

    # Anything enqueued later on the caller's stream runs after both f! and j!.
    CUDA.wait(f_done, main_stream)
    return CUDA.wait(j_done, main_stream)
end

"""
    compute_fj_spawn!(f, j, U, f!, j!)

Run `f!(f, U)` and `j!(j, U)` concurrently in two spawned tasks and block
until both tasks have finished.

Each task calls `CUDA.synchronize()` after its work, so once the tasks have
been joined the GPU work they launched has completed as well (per CUDA.jl's
task-based concurrency model, where `synchronize()` acts on the calling
task's stream).

The tasks are joined with `Base.@sync`. `CUDA.@sync` must not be used for
this: it only evaluates the expression and then synchronizes the GPU, it does
not wait for tasks created with `Threads.@spawn`, so the function could
return before `f!` and `j!` had even started.

As before, the value of the `@sync` block (the last spawned `Task`) is
returned.
"""
@inline function compute_fj_spawn!(f, j, U, f!, j!)
    # Drain work already queued by the calling task before the spawned
    # tasks start reading `U`.
    CUDA.synchronize()
    Base.@sync begin
        Base.Threads.@spawn begin
            f!(f, U)
            CUDA.synchronize() # wait for the GPU work launched by this task
            nothing
        end
        Base.Threads.@spawn begin
            j!(j, U)
            CUDA.synchronize() # wait for the GPU work launched by this task
            nothing
        end
    end
end

end
17 changes: 17 additions & 0 deletions src/solvers/compute_fj.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
    compute_fj!(f, j, U, f!, j!, context)

Fallback used when `context` is `nothing` or any
`ClimaComms.AbstractCommsContext` without a more specific method: call
`f!(f, U)` and then `j!(j, U)` one after the other on the calling task.

Returns whatever `j!(j, U)` returns.
"""
@inline function compute_fj!(
    f,
    j,
    U,
    f!,
    j!,
    ::Union{Nothing, ClimaComms.AbstractCommsContext},
)
    f!(f, U)
    return j!(j, U)
end

"""
    compute_fj!(f, j, U, f!, j!, ::SingletonCommsContext{CPUMultiThreaded})

Multithreaded CPU variant: `f!(f, U)` and `j!(j, U)` are each run in their
own spawned task, and the enclosing `Base.@sync` block waits for both tasks
before returning. Each task evaluates to `nothing`.
"""
@inline function compute_fj!(
    f,
    j,
    U,
    f!,
    j!,
    ::ClimaComms.SingletonCommsContext{ClimaComms.CPUMultiThreaded},
)
    Base.@sync begin
        Base.Threads.@spawn (f!(f, U); nothing)
        Base.Threads.@spawn (j!(j, U); nothing)
    end
end

0 comments on commit b29df31

Please sign in to comment.