Skip to content

Commit

Permalink
Add has_rocm for OpenMPI
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 22, 2024
1 parent 690faae commit 462b225
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/src/reference/library.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ MPI.MPI_LIBRARY_VERSION_STRING
```@docs
MPI.versioninfo
MPI.has_cuda
MPI.has_rocm
MPI.identify_implementation
```
27 changes: 26 additions & 1 deletion src/environment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ Wtime() = API.MPI_Wtime()
Check if the MPI implementation is known to have CUDA support. Currently only Open MPI
provides a mechanism to check, so it will return `false` with other implementations
(unless overriden).
(unless overriden). For "IBMSpectrumMPI" it will return `true`.
This can be overriden by setting the `JULIA_MPI_HAS_CUDA` environment variable to `true`
or `false`.
Expand All @@ -344,3 +344,28 @@ function has_cuda()
return parse(Bool, flag)
end
end

"""
MPI.has_rocm()
Check if the MPI implementation is known to have ROCm support. Currently only Open MPI
provides a mechanism to check, so it will return `false` with other implementations
(unless overriden).
This can be overriden by setting the `JULIA_MPI_HAS_ROCM` environment variable to `true`
or `false`.
"""
function has_rocm()
flag = get(ENV, "JULIA_MPI_HAS_ROCM", nothing)
if flag === nothing
# Only Open MPI provides a function to check ROCm support
@static if MPI_LIBRARY == "OpenMPI" && MPI_LIBRARY_VERSION v"5"
# int MPIX_Query_rocm_support(void)
return @ccall libmpi.MPIX_Query_rocm_support()::Bool
else
return false
end
else
return parse(Bool, flag)
end
end
6 changes: 6 additions & 0 deletions test/test_basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ if get(ENV,"JULIA_MPI_TEST_ARRAYTYPE","") == "CuArray"
@test MPI.has_cuda()
end

@test MPI.has_rocm() isa Bool

if get(ENV,"JULIA_MPI_TEST_ARRAYTYPE","") == "ROCArray"
@test MPI.has_rocm()
end

@test !MPI.Finalized()
MPI.Finalize()
@test MPI.Finalized()

0 comments on commit 462b225

Please sign in to comment.