diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index fc0053c4..d94fcf94 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,7 +1,7 @@ agents: queue: new-central slurm_mem: 8G - modules: climacommon/2024_03_18 + modules: climacommon/2024_05_27 env: OPENBLAS_NUM_THREADS: 1 diff --git a/NEWS.md b/NEWS.md index 3f268454..42de4077 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,15 @@ ClimaComms.jl Release Notes ======================== +v0.7.0 +------- + +- ![][badge-💥breaking] `ClimaComms.@import_required_backends` was removed, as there were code loading issues. It is now recommended to use the following code loading pattern: + ```julia + ClimaComms.cuda_is_required() && import CUDA + ClimaComms.mpi_is_required() && import MPI + ``` + v0.6.0 ------- diff --git a/Project.toml b/Project.toml index be50f10f..c7f28d7f 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,7 @@ authors = [ "Jake Bolewski ", "Gabriele Bozzola ", ] -version = "0.6.0" +version = "0.7.0" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/docs/src/index.md b/docs/src/index.md index d158f3e7..41445852 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -11,7 +11,8 @@ ClimaComms ## Loading ```@docs -ClimaComms.import_required_backends +ClimaComms.cuda_is_required +ClimaComms.mpi_is_required ``` ## Devices diff --git a/ext/ClimaCommsCUDAExt.jl b/ext/ClimaCommsCUDAExt.jl index 3618c62a..3c6cf1d0 100644 --- a/ext/ClimaCommsCUDAExt.jl +++ b/ext/ClimaCommsCUDAExt.jl @@ -15,4 +15,9 @@ end ClimaComms.array_type(::ClimaComms.CUDADevice) = CUDA.CuArray +# Extending ClimaComms methods that operate on expressions (cannot use dispatch here) +ClimaComms.cuda_sync(expr) = CUDA.@sync expr +ClimaComms.cuda_time(expr) = CUDA.@time expr +ClimaComms.cuda_elasped(expr) = CUDA.@elapsed expr + end diff --git a/src/context.jl b/src/context.jl index ea9584d4..7cf2b514 100644 --- a/src/context.jl +++ b/src/context.jl @@ -34,11 +34,6 @@ Behavior can be overridden by setting the `CLIMACOMMS_CONTEXT` environment varia to either `MPI` or `SINGLETON`. """ function context(device = device(); target_context = context_type()) - if target_context == :MPICommsContext && mpi_ext_is_not_loaded() - error( - "Loading MPI.jl is required to use MPICommsContext. You might want to call ClimaComms.@import_required_backends", - ) - end ContextConstructor = getproperty(ClimaComms, target_context) return ContextConstructor(device) end diff --git a/src/devices.jl b/src/devices.jl index e7fee85f..c82cd680 100644 --- a/src/devices.jl +++ b/src/devices.jl @@ -76,11 +76,6 @@ The default is `CPU`. """ function device() target_device = device_type() - if target_device == :CUDADevice && cuda_ext_is_not_loaded() - error( - "Loading CUDA.jl is required to use CUDADevice. You might want to call ClimaComms.@import_required_backends", - ) - end DeviceConstructor = getproperty(ClimaComms, target_device) return DeviceConstructor() end @@ -129,6 +124,8 @@ macro threaded(device, loop) end end +function cuda_time end + """ @time device expr @@ -145,24 +142,19 @@ CUDA.@time expr for CUDA devices. """ macro time(device, expr) - return esc( - quote - if $device isa $CUDADevice - @static if isnothing( - $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt), - ) - error("CUDA not loaded") - else - $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@time $expr - end - else - @assert $device isa $AbstractDevice - $Base.@time $(expr) - end - end, - ) + CC = ClimaComms + return esc(quote + if $device isa $CUDADevice + $(CC).cuda_time($expr) + else + @assert $device isa $AbstractDevice + $Base.@time $(expr) + end + end) end +function cuda_elasped end + """ @elapsed device expr @@ -179,24 +171,19 @@ CUDA.@elapsed expr for CUDA devices. """ macro elapsed(device, expr) - return esc( - quote - if $device isa $CUDADevice - @static if isnothing( - $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt), - ) - error("CUDA not loaded") - else - $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@elapsed $expr - end - else - @assert $device isa $AbstractDevice - $Base.@elapsed $(expr) - end - end, - ) + CC = ClimaComms + return esc(quote + if $device isa $CUDADevice + $(CC).cuda_elasped($expr) + else + @assert $device isa $AbstractDevice + $Base.@elapsed $(expr) + end + end) end +function cuda_sync end + """ @sync device expr @@ -233,26 +220,17 @@ to synchronize), then you may want to simply use [`@cuda_sync`](@ref). """ macro sync(device, expr) # https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207 - return esc( - quote - if $device isa $CUDADevice - @static if isnothing( - $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt), - ) - error("CUDA not loaded") - else - $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@sync begin - $(expr) - end - end - else - @assert $device isa $AbstractDevice - $Base.@sync begin - $(expr) - end + CC = ClimaComms + return esc(quote + if $device isa $CUDADevice + $(CC).cuda_sync($expr) + else + @assert $device isa $AbstractDevice + $Base.@sync begin + $(expr) end - end, - ) + end + end) end """ @@ -272,22 +250,13 @@ for CUDA devices. """ macro cuda_sync(device, expr) # https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207 - return esc( - quote - if $device isa $CUDADevice - @static if isnothing( - $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt), - ) - error("CUDA not loaded") - else - $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@sync begin - $(expr) - end - end - else - @assert $device isa $AbstractDevice - $(expr) - end - end, - ) + CC = ClimaComms + return esc(quote + if $device isa $CUDADevice + $(CC).cuda_sync($expr) + else + @assert $device isa $AbstractDevice + $(expr) + end + end) end diff --git a/src/loading.jl b/src/loading.jl index 16944a46..4513bc32 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -1,38 +1,27 @@ import ..ClimaComms -export import_required_backends - -function mpi_is_required() - return context_type() == :MPICommsContext -end - -function mpi_ext_is_not_loaded() - return isnothing(Base.get_extension(ClimaComms, :ClimaCommsMPIExt)) -end +""" + mpi_is_required() -function cuda_is_required() - return device_type() == :CUDADevice -end +Returns a Bool indicating if MPI should be loaded, based on the +`ENV["CLIMACOMMS_CONTEXT"]`. See [`ClimaComms.context`](@ref) for +more information. -function cuda_ext_is_not_loaded() - return isnothing(Base.get_extension(ClimaComms, :ClimaCommsCUDAExt)) -end +```julia +mpi_is_required() && using MPI +``` +""" +mpi_is_required() = context_type() == :MPICommsContext """ - ClimaComms.@import_required_backends + cuda_is_required() + +Returns a Bool indicating if CUDA should be loaded, based on the +`ENV["CLIMACOMMS_DEVICE"]`. See [`ClimaComms.device`](@ref) for +more information. -If the desired context is MPI (as determined by `ClimaComms.context()`), try loading MPI.jl. -If the desired device is CUDA (as determined by `ClimaComms.device()`), try loading CUDA.jl. +```julia +cuda_is_required() && using CUDA +``` """ -macro import_required_backends() - return quote - @static if $mpi_is_required() && $mpi_ext_is_not_loaded() - import MPI - @info "Loaded MPI.jl" - end - @static if $cuda_is_required() && $cuda_ext_is_not_loaded() - import CUDA - @info "Loaded CUDA.jl" - end - end -end +cuda_is_required() = device_type() == :CUDADevice diff --git a/test/runtests.jl b/test/runtests.jl index 5574fdd3..e6dc29c0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Test using ClimaComms -ClimaComms.@import_required_backends +ClimaComms.cuda_is_required() && import CUDA +ClimaComms.mpi_is_required() && import MPI context = ClimaComms.context() pid, nprocs = ClimaComms.init(context)