From 53b3b9f2a5e31703ca61ffd8f38ef7d25973dcc4 Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Thu, 25 Apr 2024 15:57:41 -0700 Subject: [PATCH] Add second import guard Mostly so that we can print a log message informing the user we loaded MPI/CUDA --- src/context.jl | 5 +++++ src/devices.jl | 33 ++++++++++++++++++++++----------- src/loading.jl | 38 +++++++++++++++++++------------------- 3 files changed, 46 insertions(+), 30 deletions(-) diff --git a/src/context.jl b/src/context.jl index 7cf2b514..ea9584d4 100644 --- a/src/context.jl +++ b/src/context.jl @@ -34,6 +34,11 @@ Behavior can be overridden by setting the `CLIMACOMMS_CONTEXT` environment varia to either `MPI` or `SINGLETON`. """ function context(device = device(); target_context = context_type()) + if target_context == :MPICommsContext && mpi_ext_is_not_loaded() + error( + "Loading MPI.jl is required to use MPICommsContext. You might want to call ClimaComms.@import_required_backends", + ) + end ContextConstructor = getproperty(ClimaComms, target_context) return ContextConstructor(device) end diff --git a/src/devices.jl b/src/devices.jl index c2709496..e7fee85f 100644 --- a/src/devices.jl +++ b/src/devices.jl @@ -46,6 +46,21 @@ function device_functional end device_functional(::CPUSingleThreaded) = true device_functional(::CPUMultiThreaded) = true +function device_type() + env_var = get(ENV, "CLIMACOMMS_DEVICE", "CPU") + if env_var == "CPU" + return Threads.nthreads() > 1 ? :CPUMultiThreaded : :CPUSingleThreaded + elseif env_var == "CPUSingleThreaded" + return :CPUSingleThreaded + elseif env_var == "CPUMultiThreaded" + return :CPUMultiThreaded + elseif env_var == "CUDA" + return :CUDADevice + else + error("Invalid CLIMACOMMS_DEVICE: $env_var") + end +end + """ ClimaComms.device() @@ -60,18 +75,14 @@ Allowed values: The default is `CPU`. """ function device() - env_var = get(ENV, "CLIMACOMMS_DEVICE", "CPU") - if env_var == "CPU" - return Threads.nthreads() > 1 ? CPUMultiThreaded() : CPUSingleThreaded() - elseif env_var == "CPUSingleThreaded" - return CPUSingleThreaded() - elseif env_var == "CPUMultiThreaded" - return CPUMultiThreaded() - elseif env_var == "CUDA" - return CUDADevice() - else - error("Invalid CLIMACOMMS_DEVICE: $env_var") + target_device = device_type() + if target_device == :CUDADevice && cuda_ext_is_not_loaded() + error( + "Loading CUDA.jl is required to use CUDADevice. You might want to call ClimaComms.@import_required_backends", + ) end + DeviceConstructor = getproperty(ClimaComms, target_device) + return DeviceConstructor() end """ diff --git a/src/loading.jl b/src/loading.jl index ce6c0c1f..16944a46 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -1,11 +1,21 @@ +import ..ClimaComms + export import_required_backends -function mpi_required() +function mpi_is_required() return context_type() == :MPICommsContext end -function cuda_required() - return device() isa CUDADevice +function mpi_ext_is_not_loaded() + return isnothing(Base.get_extension(ClimaComms, :ClimaCommsMPIExt)) +end + +function cuda_is_required() + return device_type() == :CUDADevice +end + +function cuda_ext_is_not_loaded() + return isnothing(Base.get_extension(ClimaComms, :ClimaCommsCUDAExt)) end """ @@ -16,23 +26,13 @@ If the desired device is CUDA (as determined by `ClimaComms.device()`), try load """ macro import_required_backends() return quote - @static if $mpi_required() - try - import MPI - catch - error( - "Cannot load MPI.jl. Make sure it is included in your environment stack.", - ) - end + @static if $mpi_is_required() && $mpi_ext_is_not_loaded() + import MPI + @info "Loaded MPI.jl" end - @static if $cuda_required() - try - import CUDA - catch - error( - "Cannot load CUDA.jl. Make sure it is included in your environment stack.", - ) - end + @static if $cuda_is_required() && $cuda_ext_is_not_loaded() + import CUDA + @info "Loaded CUDA.jl" end end end