Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify code loading, bump minor version #78

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
agents:
queue: new-central
slurm_mem: 8G
modules: climacommon/2024_03_18
modules: climacommon/2024_05_27

env:
OPENBLAS_NUM_THREADS: 1
Expand Down
9 changes: 9 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
ClimaComms.jl Release Notes
========================

v0.7.0
-------

- ![][badge-💥breaking] `ClimaComms.@import_required_backends` was removed, as there were code loading issues. It is now recommended to use the following code loading pattern:
```julia
ClimaComms.cuda_is_required() && import CUDA
ClimaComms.mpi_is_required() && import MPI
```

v0.6.0
-------

Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ authors = [
"Jake Bolewski <[email protected]>",
"Gabriele Bozzola <[email protected]>",
]
version = "0.6.0"
version = "0.7.0"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
3 changes: 2 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ ClimaComms
## Loading

```@docs
ClimaComms.import_required_backends
ClimaComms.cuda_is_required
ClimaComms.mpi_is_required
```

## Devices
Expand Down
5 changes: 5 additions & 0 deletions ext/ClimaCommsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,9 @@ end

ClimaComms.array_type(::ClimaComms.CUDADevice) = CUDA.CuArray

# Extending ClimaComms methods that operate on expressions (cannot use dispatch here)
ClimaComms.cuda_sync(expr) = CUDA.@sync expr
ClimaComms.cuda_time(expr) = CUDA.@time expr
ClimaComms.cuda_elasped(expr) = CUDA.@elapsed expr

end
5 changes: 0 additions & 5 deletions src/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ Behavior can be overridden by setting the `CLIMACOMMS_CONTEXT` environment varia
to either `MPI` or `SINGLETON`.
"""
function context(device = device(); target_context = context_type())
if target_context == :MPICommsContext && mpi_ext_is_not_loaded()
error(
"Loading MPI.jl is required to use MPICommsContext. You might want to call ClimaComms.@import_required_backends",
)
end
charleskawczynski marked this conversation as resolved.
Show resolved Hide resolved
ContextConstructor = getproperty(ClimaComms, target_context)
return ContextConstructor(device)
end
Expand Down
117 changes: 43 additions & 74 deletions src/devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,6 @@ The default is `CPU`.
"""
function device()
target_device = device_type()
if target_device == :CUDADevice && cuda_ext_is_not_loaded()
error(
"Loading CUDA.jl is required to use CUDADevice. You might want to call ClimaComms.@import_required_backends",
)
end
DeviceConstructor = getproperty(ClimaComms, target_device)
return DeviceConstructor()
end
Expand Down Expand Up @@ -129,6 +124,8 @@ macro threaded(device, loop)
end
end

function cuda_time end

"""
@time device expr

Expand All @@ -145,24 +142,19 @@ CUDA.@time expr
for CUDA devices.
"""
macro time(device, expr)
return esc(
quote
if $device isa $CUDADevice
@static if isnothing(
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt),
)
error("CUDA not loaded")
else
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@time $expr
end
else
@assert $device isa $AbstractDevice
$Base.@time $(expr)
end
end,
)
CC = ClimaComms
return esc(quote
if $device isa $CUDADevice
$(CC).cuda_time($expr)
else
@assert $device isa $AbstractDevice
$Base.@time $(expr)
end
end)
end

function cuda_elasped end

"""
@elapsed device expr

Expand All @@ -179,24 +171,19 @@ CUDA.@elapsed expr
for CUDA devices.
"""
macro elapsed(device, expr)
return esc(
quote
if $device isa $CUDADevice
@static if isnothing(
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt),
)
error("CUDA not loaded")
else
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@elapsed $expr
end
else
@assert $device isa $AbstractDevice
$Base.@elapsed $(expr)
end
end,
)
CC = ClimaComms
return esc(quote
if $device isa $CUDADevice
$(CC).cuda_elasped($expr)
else
@assert $device isa $AbstractDevice
$Base.@elapsed $(expr)
end
end)
end

function cuda_sync end

"""
@sync device expr

Expand Down Expand Up @@ -233,26 +220,17 @@ to synchronize), then you may want to simply use [`@cuda_sync`](@ref).
"""
macro sync(device, expr)
# https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207
return esc(
quote
if $device isa $CUDADevice
@static if isnothing(
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt),
)
error("CUDA not loaded")
else
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@sync begin
$(expr)
end
end
else
@assert $device isa $AbstractDevice
$Base.@sync begin
$(expr)
end
CC = ClimaComms
return esc(quote
if $device isa $CUDADevice
$(CC).cuda_sync($expr)
else
@assert $device isa $AbstractDevice
$Base.@sync begin
$(expr)
end
end,
)
end
end)
end

"""
Expand All @@ -272,22 +250,13 @@ for CUDA devices.
"""
macro cuda_sync(device, expr)
# https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207
return esc(
quote
if $device isa $CUDADevice
@static if isnothing(
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt),
)
error("CUDA not loaded")
else
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@sync begin
$(expr)
end
end
else
@assert $device isa $AbstractDevice
$(expr)
end
end,
)
CC = ClimaComms
return esc(quote
if $device isa $CUDADevice
$(CC).cuda_sync($expr)
else
@assert $device isa $AbstractDevice
$(expr)
end
end)
end
49 changes: 19 additions & 30 deletions src/loading.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,27 @@
import ..ClimaComms

export import_required_backends

function mpi_is_required()
return context_type() == :MPICommsContext
end

function mpi_ext_is_not_loaded()
return isnothing(Base.get_extension(ClimaComms, :ClimaCommsMPIExt))
end
"""
mpi_is_required()

function cuda_is_required()
return device_type() == :CUDADevice
end
Returns a Bool indicating if MPI should be loaded, based on the
`ENV["CLIMACOMMS_CONTEXT"]`. See [`ClimaComms.context`](@ref) for
more information.

function cuda_ext_is_not_loaded()
return isnothing(Base.get_extension(ClimaComms, :ClimaCommsCUDAExt))
end
```julia
mpi_is_required() && using MPI
```
"""
mpi_is_required() = context_type() == :MPICommsContext

"""
ClimaComms.@import_required_backends
cuda_is_required()

Returns a Bool indicating if CUDA should be loaded, based on the
`ENV["CLIMACOMMS_DEVICE"]`. See [`ClimaComms.device`](@ref) for
more information.

If the desired context is MPI (as determined by `ClimaComms.context()`), try loading MPI.jl.
If the desired device is CUDA (as determined by `ClimaComms.device()`), try loading CUDA.jl.
```julia
cuda_is_required() && using CUDA
```
"""
macro import_required_backends()
return quote
@static if $mpi_is_required() && $mpi_ext_is_not_loaded()
import MPI
@info "Loaded MPI.jl"
end
@static if $cuda_is_required() && $cuda_ext_is_not_loaded()
import CUDA
@info "Loaded CUDA.jl"
end
end
end
cuda_is_required() = device_type() == :CUDADevice
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test
using ClimaComms
ClimaComms.@import_required_backends
ClimaComms.cuda_is_required() && import CUDA
ClimaComms.mpi_is_required() && import MPI

context = ClimaComms.context()
pid, nprocs = ClimaComms.init(context)
Expand Down
Loading