Skip to content

Commit

Permalink
Add second import guard
Browse files Browse the repository at this point in the history
Mostly so that we can print a log message informing the user we loaded
MPI/CUDA
  • Loading branch information
Sbozzolo committed Apr 29, 2024
1 parent 71d8fa7 commit 53b3b9f
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 30 deletions.
5 changes: 5 additions & 0 deletions src/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ Behavior can be overridden by setting the `CLIMACOMMS_CONTEXT` environment varia
to either `MPI` or `SINGLETON`.
"""
function context(device = device(); target_context = context_type())
if target_context == :MPICommsContext && mpi_ext_is_not_loaded()
error(
"Loading MPI.jl is required to use MPICommsContext. You might want to call ClimaComms.@import_required_backends",
)
end
ContextConstructor = getproperty(ClimaComms, target_context)
return ContextConstructor(device)
end
Expand Down
33 changes: 22 additions & 11 deletions src/devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ function device_functional end
device_functional(::CPUSingleThreaded) = true
device_functional(::CPUMultiThreaded) = true

function device_type()
env_var = get(ENV, "CLIMACOMMS_DEVICE", "CPU")
if env_var == "CPU"
return Threads.nthreads() > 1 ? :CPUMultiThreaded : :CPUSingleThreaded
elseif env_var == "CPUSingleThreaded"
return :CPUSingleThreaded
elseif env_var == "CPUMultiThreaded"
return :CPUMultiThreaded
elseif env_var == "CUDA"
return :CUDADevice
else
error("Invalid CLIMACOMMS_DEVICE: $env_var")
end
end

"""
ClimaComms.device()
Expand All @@ -60,18 +75,14 @@ Allowed values:
The default is `CPU`.
"""
function device()
env_var = get(ENV, "CLIMACOMMS_DEVICE", "CPU")
if env_var == "CPU"
return Threads.nthreads() > 1 ? CPUMultiThreaded() : CPUSingleThreaded()
elseif env_var == "CPUSingleThreaded"
return CPUSingleThreaded()
elseif env_var == "CPUMultiThreaded"
return CPUMultiThreaded()
elseif env_var == "CUDA"
return CUDADevice()
else
error("Invalid CLIMACOMMS_DEVICE: $env_var")
target_device = device_type()
if target_device == :CUDADevice && cuda_ext_is_not_loaded()
error(
"Loading CUDA.jl is required to use CUDADevice. You might want to call ClimaComms.@import_required_backends",
)
end
DeviceConstructor = getproperty(ClimaComms, target_device)
return DeviceConstructor()
end

"""
Expand Down
38 changes: 19 additions & 19 deletions src/loading.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import ..ClimaComms

export import_required_backends

function mpi_required()
function mpi_is_required()
return context_type() == :MPICommsContext
end

function cuda_required()
return device() isa CUDADevice
function mpi_ext_is_not_loaded()
return isnothing(Base.get_extension(ClimaComms, :ClimaCommsMPIExt))
end

function cuda_is_required()
return device_type() == :CUDADevice
end

function cuda_ext_is_not_loaded()
return isnothing(Base.get_extension(ClimaComms, :ClimaCommsCUDAExt))
end

"""
Expand All @@ -16,23 +26,13 @@ If the desired device is CUDA (as determined by `ClimaComms.device()`), try load
"""
macro import_required_backends()
return quote
@static if $mpi_required()
try
import MPI
catch
error(
"Cannot load MPI.jl. Make sure it is included in your environment stack.",
)
end
@static if $mpi_is_required() && $mpi_ext_is_not_loaded()
import MPI
@info "Loaded MPI.jl"
end
@static if $cuda_required()
try
import CUDA
catch
error(
"Cannot load CUDA.jl. Make sure it is included in your environment stack.",
)
end
@static if $cuda_is_required() && $cuda_ext_is_not_loaded()
import CUDA
@info "Loaded CUDA.jl"
end
end
end

0 comments on commit 53b3b9f

Please sign in to comment.