diff --git a/docs/src/index.md b/docs/src/index.md index bf0e7db3..88dcc8dc 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -21,6 +21,7 @@ ClimaComms.array_type ClimaComms.@threaded ClimaComms.@time ClimaComms.@elapsed +ClimaComms.@sync ``` ## Contexts diff --git a/src/devices.jl b/src/devices.jl index 1133b483..a994bb0f 100644 --- a/src/devices.jl +++ b/src/devices.jl @@ -163,3 +163,34 @@ macro elapsed(device, expr) end end end + +""" + @sync device expr + +Device-flexible `@sync`. + +Lowers to +```julia +@sync expr +``` +for CPU devices and +```julia +CUDA.@sync expr +``` +for CUDA devices. +""" +macro sync(device, expr) + # https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207 + return esc(quote + if $(device) isa $CUDADevice + $CUDA.@sync begin + $(expr) + end + else + @assert $(device) isa $AbstractDevice + $Base.@sync begin + $(expr) + end + end + end) +end diff --git a/test/hygiene.jl b/test/hygiene.jl index 6465e2b2..b1a794e2 100644 --- a/test/hygiene.jl +++ b/test/hygiene.jl @@ -14,6 +14,10 @@ function test_macro_hyhiene(dev) CC.@elapsed dev for i in 1:n sin.(rand(10)) end + + CC.@sync dev for i in 1:n + sin.(rand(10)) + end end dev = CC.device()