Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for async f! and j! #229

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
CudaExt = "CUDA"

[compat]
ClimaComms = "0.4, 0.5"
Colors = "0.12"
Expand Down
60 changes: 60 additions & 0 deletions ext/CudaExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
module CudaExt

import CUDA
import ClimaComms: SingletonCommsContext, CUDADevice
import ClimaTimeSteppers: compute_fj!

@inline function compute_fj!(f, j, U, f!, j!, ::SingletonCommsContext{CUDADevice})
# TODO: we should benchmark these two options to
# see if one is preferrable over the other
if Base.Threads.nthreads() > 1
compute_fj_spawn!(f, j, U, f!, j!)
else
compute_fj_streams!(f, j, U, f!, j!)
end
end

@inline function compute_fj_streams!(f, j, U, f!, j!)
event = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
CUDA.record(event, CUDA.stream()) # record event on main stream

stream1 = CUDA.CuStream() # make a stream
local event1
CUDA.stream!(stream1) do # work to be done by stream1
CUDA.wait(event, stream1) # make stream1 wait on event (host continues)
f!(f, U)
event1 = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
end
CUDA.record(event1, stream1) # record event1 on stream1

stream2 = CUDA.CuStream() # make a stream
local event2
CUDA.stream!(stream2) do # work to be done by stream2
CUDA.wait(event, stream2) # make stream2 wait on event (host continues)
j!(j, U)
event2 = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
end
CUDA.record(event2, stream2) # record event2 on stream2

CUDA.wait(event1, CUDA.stream()) # make main stream wait on event1
CUDA.wait(event2, CUDA.stream()) # make main stream wait on event2
end

@inline function compute_fj_spawn!(f, j, U, f!, j!)

CUDA.synchronize()
CUDA.@sync begin
Base.Threads.@spawn begin
f!(f, U)
CUDA.synchronize()
nothing
end
Base.Threads.@spawn begin
j!(j, U)
CUDA.synchronize()
nothing
end
end
end

end
17 changes: 17 additions & 0 deletions src/solvers/compute_fj.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
@inline function compute_fj!(f, j, U, f!, j!, ::Union{Nothing, ClimaComms.AbstractCommsContext})
f!(f, U)
j!(j, U)
end

@inline function compute_fj!(f, j, U, f!, j!, ::ClimaComms.SingletonCommsContext{ClimaComms.CPUMultiThreaded})
Base.@sync begin
Base.Threads.@spawn begin
f!(f, U)
nothing
end
Base.Threads.@spawn begin
j!(j, U)
nothing
end
end
end
Loading