diff --git a/Project.toml b/Project.toml index e98f541c..c2695492 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,7 @@ authors = [ "Jake Bolewski ", "Gabriele Bozzola ", ] -version = "0.6.1" +version = "0.6.2" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/ext/ClimaCommsCUDAExt.jl b/ext/ClimaCommsCUDAExt.jl index fff76b94..171583ae 100644 --- a/ext/ClimaCommsCUDAExt.jl +++ b/ext/ClimaCommsCUDAExt.jl @@ -3,23 +3,27 @@ module ClimaCommsCUDAExt import CUDA import ClimaComms +import ClimaComms: CUDADevice -function ClimaComms._assign_device(::ClimaComms.CUDADevice, rank_number) +function ClimaComms._assign_device(::CUDADevice, rank_number) CUDA.device!(rank_number % CUDA.ndevices()) return nothing end -function ClimaComms.device_functional(::ClimaComms.CUDADevice) +function ClimaComms.device_functional(::CUDADevice) return CUDA.functional() end -ClimaComms.array_type(::ClimaComms.CUDADevice) = CUDA.CuArray -ClimaComms.allowscalar(f, ::ClimaComms.CUDADevice, args...; kwargs...) = +ClimaComms.array_type(::CUDADevice) = CUDA.CuArray +ClimaComms.allowscalar(f, ::CUDADevice, args...; kwargs...) = CUDA.@allowscalar f(args...; kwargs...) -# Extending ClimaComms methods that operate on expressions (cannot use dispatch here) -ClimaComms.cuda_sync(expr) = CUDA.@sync expr -ClimaComms.cuda_time(expr) = CUDA.@time expr -ClimaComms.cuda_elasped(expr) = CUDA.@elapsed expr +# Device-flexible `sync`/`time`/`elapsed`: methods dispatching on `CUDADevice` that wrap the call to `f` in the corresponding CUDA macro +ClimaComms.sync(f::F, ::CUDADevice, args...; kwargs...) where {F} = + CUDA.@sync f(args...; kwargs...) +ClimaComms.time(f::F, ::CUDADevice, args...; kwargs...) where {F} = + CUDA.@time f(args...; kwargs...) +ClimaComms.elapsed(f::F, ::CUDADevice, args...; kwargs...) where {F} = + CUDA.@elapsed f(args...; kwargs...) end diff --git a/src/devices.jl b/src/devices.jl index bd77d089..c216d05a 100644 --- a/src/devices.jl +++ b/src/devices.jl @@ -129,66 +129,151 @@ macro threaded(device, loop) end end -function cuda_time end +""" + time(f, device::AbstractDevice, args...; kwargs...) 
+ +Device-flexible function that calls `@time`. + +Calls +```julia +@time f(args...; kwargs...) +``` +for CPU devices and +```julia +CUDA.@time f(args...; kwargs...) +``` +for CUDA devices. +""" +function time(f::F, device::AbstractDevice, args...; kwargs...) where {F} + Base.@time begin + f(args...; kwargs...) + end +end """ - @time device expr + elapsed(f::F, device::AbstractDevice, args...; kwargs...) -Device-flexible `@time`. +Device-flexible function that calls `@elapsed`. -Lowers to +Calls ```julia -@time expr +@elapsed f(args...; kwargs...) ``` for CPU devices and ```julia -CUDA.@time expr +CUDA.@elapsed f(args...; kwargs...) ``` for CUDA devices. """ -macro time(device, expr) - __CC__ = ClimaComms - return esc(quote - if $device isa $CUDADevice - $(__CC__).cuda_time($expr) - else - @assert $device isa $AbstractDevice - $Base.@time $(expr) +function elapsed(f::F, device::AbstractDevice, args...; kwargs...) where {F} + Base.@elapsed begin + f(args...; kwargs...) + end +end + +""" + sync(f, ::AbstractDevice, args...; kwargs...) + +Device-flexible function that calls `@sync`. + +Calls +```julia +@sync f(args...; kwargs...) +``` +for CPU devices and +```julia +CUDA.@sync f(args...; kwargs...) +``` +for CUDA devices. + +An example use-case of this might be: +```julia +BenchmarkTools.@benchmark begin + if ClimaComms.device() isa ClimaComms.CUDADevice + CUDA.@sync begin + launch_cuda_kernels_or_spawn_tasks!(...) end - end) + elseif ClimaComms.device() isa ClimaComms.CPUMultiThreading + Base.@sync begin + launch_cuda_kernels_or_spawn_tasks!(...) + end + end end +``` -function cuda_elasped end +If the CPU version of the above example does not leverage +spawned tasks (which require using `Base.@sync` or `Threads.wait` +to synchronize), then you may want to simply use [`cuda_sync`](@ref). +""" +function sync(f::F, ::AbstractDevice, args...; kwargs...) where {F} + Base.@sync begin + f(args...; kwargs...) + end +end """ - @elapsed device expr + cuda_sync(f, ::AbstractDevice, args...; kwargs...) 
-Device-flexible `@elapsed`. +Device-flexible function that (may) call `CUDA.@sync`. + +Calls +```julia +f(args...; kwargs...) +``` +for CPU devices and +```julia +CUDA.@sync f(args...; kwargs...) +``` +for CUDA devices. +""" +function cuda_sync(f::F, ::AbstractDevice, args...; kwargs...) where {F} + f(args...; kwargs...) +end + +""" + allowscalar(f, ::AbstractDevice, args...; kwargs...) + +Device-flexible version of `CUDA.@allowscalar`. Lowers to ```julia -@elapsed expr +f(args...) ``` for CPU devices and ```julia -CUDA.@elapsed expr +CUDA.@allowscalar f(args...) ``` for CUDA devices. + +This is usefully written with closures via +```julia +allowscalar(device) do + f() +end +``` """ -macro elapsed(device, expr) +allowscalar(f, ::AbstractDevice, args...; kwargs...) = f(args...; kwargs...) + +""" + @time device expr + +Device-flexible `@time`. + +Lowers to +```julia +@time expr +``` +for CPU devices and +```julia +CUDA.@time expr +``` +for CUDA devices. +""" +macro time(device, expr) __CC__ = ClimaComms - return esc(quote - if $device isa $CUDADevice - $(__CC__).cuda_elasped($expr) - else - @assert $device isa $AbstractDevice - $Base.@elapsed $(expr) - end - end) + return :($__CC__.time(() -> $(esc(expr)), $(esc(device)))) end -function cuda_sync end - """ @sync device expr @@ -224,18 +309,8 @@ spawned tasks (which require using `Base.sync` or `Threads.wait` to synchronize), then you may want to simply use [`@cuda_sync`](@ref). """ macro sync(device, expr) - # https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207 __CC__ = ClimaComms - return esc(quote - if $device isa $CUDADevice - $(__CC__).cuda_sync($expr) - else - @assert $device isa $AbstractDevice - $Base.@sync begin - $(expr) - end - end - end) + return :($__CC__.sync(() -> $(esc(expr)), $(esc(device)))) end """ @@ -254,38 +329,26 @@ CUDA.@sync expr for CUDA devices. 
""" macro cuda_sync(device, expr) - # https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207 __CC__ = ClimaComms - return esc(quote - if $device isa $CUDADevice - $(__CC__).cuda_sync($expr) - else - @assert $device isa $AbstractDevice - $(expr) - end - end) + return :($__CC__.cuda_sync(() -> $(esc(expr)), $(esc(device)))) end """ - allowscalar(f, ::AbstractDevice, args...; kwargs...) + @elapsed device expr -Device-flexible version of `CUDA.@allowscalar`. +Device-flexible `@elapsed`. Lowers to ```julia -f(args...) +@elapsed expr ``` for CPU devices and ```julia -CUDA.@allowscalar f(args...) +CUDA.@elapsed expr ``` for CUDA devices. - -This is usefully written with closures via -```julia -allowscalar(device) do - f() -end -``` """ -allowscalar(f, ::AbstractDevice, args...; kwargs...) = f(args...; kwargs...) +macro elapsed(device, expr) + __CC__ = ClimaComms + return :($__CC__.elapsed(() -> $(esc(expr)), $(esc(device)))) +end diff --git a/test/hygiene.jl b/test/hygiene.jl index 645ed59e..ce8ba75b 100644 --- a/test/hygiene.jl +++ b/test/hygiene.jl @@ -1,5 +1,7 @@ import ClimaComms as CC +CC.@import_required_backends function test_macro_hyhiene(dev) + AT = CC.array_type(dev) n = 3 # tests that we can reach variables defined in scope CC.@threaded dev for i in 1:n if Threads.nthreads() > 1 @@ -7,20 +9,44 @@ function test_macro_hyhiene(dev) end end + CC.time(dev) do + for i in 1:n + sin.(AT(rand(1000, 1000))) + end + end + CC.@time dev for i in 1:n - sin.(rand(10)) + sin.(AT(rand(1000, 1000))) + end + + CC.elapsed(dev) do + for i in 1:n + sin.(AT(rand(10))) + end end CC.@elapsed dev for i in 1:n - sin.(rand(10)) + sin.(AT(rand(10))) + end + + CC.sync(dev) do + for i in 1:n + sin.(AT(rand(10))) + end end CC.@sync dev for i in 1:n - sin.(rand(10)) + sin.(AT(rand(10))) + end + + CC.cuda_sync(dev) do + for i in 1:n + sin.(AT(rand(10))) + end end CC.@cuda_sync dev for i in 1:n - sin.(rand(10)) + sin.(AT(rand(10))) end end dev = CC.device()