Skip to content

Commit

Permalink
Fix macros
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jun 7, 2024
1 parent ae67a34 commit 6d006d2
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 75 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ authors = [
"Jake Bolewski <[email protected]>",
"Gabriele Bozzola <[email protected]>",
]
version = "0.6.1"
version = "0.6.2"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
18 changes: 11 additions & 7 deletions ext/ClimaCommsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,27 @@ module ClimaCommsCUDAExt
import CUDA

import ClimaComms
import ClimaComms: CUDADevice

function ClimaComms._assign_device(::ClimaComms.CUDADevice, rank_number)
function ClimaComms._assign_device(::CUDADevice, rank_number)
CUDA.device!(rank_number % CUDA.ndevices())
return nothing
end

function ClimaComms.device_functional(::ClimaComms.CUDADevice)
function ClimaComms.device_functional(::CUDADevice)
return CUDA.functional()
end

ClimaComms.array_type(::ClimaComms.CUDADevice) = CUDA.CuArray
ClimaComms.allowscalar(f, ::ClimaComms.CUDADevice, args...; kwargs...) =
ClimaComms.array_type(::CUDADevice) = CUDA.CuArray
ClimaComms.allowscalar(f, ::CUDADevice, args...; kwargs...) =
CUDA.@allowscalar f(args...; kwargs...)

# Extending ClimaComms methods that operate on expressions (cannot use dispatch here)
ClimaComms.cuda_sync(expr) = CUDA.@sync expr
ClimaComms.cuda_time(expr) = CUDA.@time expr
ClimaComms.cuda_elasped(expr) = CUDA.@elapsed expr
ClimaComms.sync(f::F, ::CUDADevice, args...; kwargs...) where {F} =
CUDA.@sync f(args...; kwargs...)
ClimaComms.time(f::F, ::CUDADevice, args...; kwargs...) where {F} =
CUDA.@time f(args...; kwargs...)
ClimaComms.elapsed(f::F, ::CUDADevice, args...; kwargs...) where {F} =
CUDA.@elapsed f(args...; kwargs...)

end
189 changes: 126 additions & 63 deletions src/devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,66 +129,151 @@ macro threaded(device, loop)
end
end

function cuda_time end
"""
@time f(args...; kwargs...)
Device-flexible `@time`:
Calls
```julia
@time f(args...; kwargs...)
```
for CPU devices and
```julia
CUDA.@time f(args...; kwargs...)
```
for CUDA devices.
"""
function time(f::F, device::AbstractDevice, args...; kwargs...) where {F}
Base.@time begin
f(args...; kwargs...)
end
end

"""
@time device expr
elapsed(f::F, device::AbstractDevice, args...; kwargs...)
Device-flexible `@time`.
Device-flexible `elapsed`.
Lowers to
Calls
```julia
@time expr
@elapsed f(args...; kwargs...)
```
for CPU devices and
```julia
CUDA.@time expr
CUDA.@elapsed f(args...; kwargs...)
```
for CUDA devices.
"""
macro time(device, expr)
__CC__ = ClimaComms
return esc(quote
if $device isa $CUDADevice
$(__CC__).cuda_time($expr)
else
@assert $device isa $AbstractDevice
$Base.@time $(expr)
function elapsed(f::F, device::AbstractDevice, args...; kwargs...) where {F}
Base.@elapsed begin
f(args...; kwargs...)
end
end

"""
sync(f, ::AbstractDevice, args...; kwargs...)
Device-flexible function that calls `@sync`.
Calls
```julia
@sync f(args...; kwargs...)
```
for CPU devices and
```julia
CUDA.@sync f(args...; kwargs...)
```
for CUDA devices.
An example use-case of this might be:
```julia
BenchmarkTools.@benchmark begin
if ClimaComms.device() isa ClimaComms.CUDADevice
CUDA.@sync begin
launch_cuda_kernels_or_spawn_tasks!(...)
end
end)
elseif ClimaComms.device() isa ClimaComms.CPUMultiThreading
Base.@sync begin
launch_cuda_kernels_or_spawn_tasks!(...)
end
end
end
```
function cuda_elasped end
If the CPU version of the above example does not leverage
spawned tasks (which require using `Base.sync` or `Threads.wait`
to synchronize), then you may want to simply use [`cuda_sync`](@ref).
"""
function sync(f::F, ::AbstractDevice, args...; kwargs...) where {F}
Base.@sync begin
f(args...; kwargs...)
end
end

"""
@elapsed device expr
cuda_sync(f, ::AbstractDevice, args...; kwargs...)
Device-flexible `@elapsed`.
Device-flexible function that (may) call `CUDA.@sync`.
Calls
```julia
f(args...; kwargs...)
```
for CPU devices and
```julia
CUDA.@sync f(args...; kwargs...)
```
for CUDA devices.
"""
function cuda_sync(f::F, ::AbstractDevice, args...; kwargs...) where {F}
f(args...; kwargs...)
end

"""
allowscalar(f, ::AbstractDevice, args...; kwargs...)
Device-flexible version of `CUDA.@allowscalar`.
Lowers to
```julia
@elapsed expr
f(args...)
```
for CPU devices and
```julia
CUDA.@elapsed expr
CUDA.@allowscalar f(args...)
```
for CUDA devices.
This is usefully written with closures via
```julia
allowscalar(device) do
f()
end
```
"""
macro elapsed(device, expr)
allowscalar(f, ::AbstractDevice, args...; kwargs...) = f(args...; kwargs...)

"""
@time device expr
Device-flexible `@time`.
Lowers to
```julia
@time expr
```
for CPU devices and
```julia
CUDA.@time expr
```
for CUDA devices.
"""
macro time(device, expr)
__CC__ = ClimaComms
return esc(quote
if $device isa $CUDADevice
$(__CC__).cuda_elasped($expr)
else
@assert $device isa $AbstractDevice
$Base.@elapsed $(expr)
end
end)
return :($__CC__.time(() -> $(esc(expr)), $(esc(device))))
end

function cuda_sync end

"""
@sync device expr
Expand Down Expand Up @@ -224,18 +309,8 @@ spawned tasks (which require using `Base.sync` or `Threads.wait`
to synchronize), then you may want to simply use [`@cuda_sync`](@ref).
"""
macro sync(device, expr)
# https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207
__CC__ = ClimaComms
return esc(quote
if $device isa $CUDADevice
$(__CC__).cuda_sync($expr)
else
@assert $device isa $AbstractDevice
$Base.@sync begin
$(expr)
end
end
end)
return :($__CC__.sync(() -> $(esc(expr)), $(esc(device))))
end

"""
Expand All @@ -254,38 +329,26 @@ CUDA.@sync expr
for CUDA devices.
"""
macro cuda_sync(device, expr)
# https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207
__CC__ = ClimaComms
return esc(quote
if $device isa $CUDADevice
$(__CC__).cuda_sync($expr)
else
@assert $device isa $AbstractDevice
$(expr)
end
end)
return :($__CC__.cuda_sync(() -> $(esc(expr)), $(esc(device))))
end

"""
allowscalar(f, ::AbstractDevice, args...; kwargs...)
@elapsed device expr
Device-flexible version of `CUDA.@allowscalar`.
Device-flexible `@elapsed`.
Lowers to
```julia
f(args...)
@elapsed expr
```
for CPU devices and
```julia
CUDA.@allowscalar f(args...)
CUDA.@elapsed expr
```
for CUDA devices.
This is usefully written with closures via
```julia
allowscalar(device) do
f()
end
```
"""
allowscalar(f, ::AbstractDevice, args...; kwargs...) = f(args...; kwargs...)
macro elapsed(device, expr)
__CC__ = ClimaComms
return :($__CC__.elapsed(() -> $(esc(expr)), $(esc(device))))
end
43 changes: 39 additions & 4 deletions test/hygiene.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,61 @@
import ClimaComms as CC
CC.@import_required_backends
function test_macro_hyhiene(dev)
AT = CC.array_type(dev)
n = 3 # tests that we can reach variables defined in scope
# println("----------- @threaded")
CC.@threaded dev for i in 1:n
if Threads.nthreads() > 1
@show Threads.threadid()
end
end

# println("----------- time")
CC.time(dev) do
for i in 1:n
sin.(AT(rand(1000, 1000)))
end
end

# println("----------- @time")
CC.@time dev for i in 1:n
sin.(rand(10))
sin.(AT(rand(1000, 1000)))
end

# println("----------- elapsed")
CC.elapsed(dev) do
for i in 1:n
sin.(AT(rand(10)))
end
end

# println("----------- @elapsed")
CC.@elapsed dev for i in 1:n
sin.(rand(10))
sin.(AT(rand(10)))
end

# println("----------- sync")
CC.sync(dev) do
for i in 1:n
sin.(AT(rand(10)))
end
end

# println("----------- @sync")
CC.@sync dev for i in 1:n
sin.(rand(10))
sin.(AT(rand(10)))
end

# println("----------- cuda_sync")
CC.cuda_sync(dev) do
for i in 1:n
sin.(AT(rand(10)))
end
end

# println("----------- @cuda_sync")
CC.@cuda_sync dev for i in 1:n
sin.(rand(10))
sin.(AT(rand(10)))
end
end
dev = CC.device()
Expand Down

0 comments on commit 6d006d2

Please sign in to comment.