From 0e711e973cb3ff5d8d55bba98a61e778dc79636b Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Sat, 20 Apr 2024 14:35:06 -0700 Subject: [PATCH] Move MPI and CUDA to Julia extensions --- .buildkite/JuliaProject.toml | 18 -- .buildkite/pipeline.yml | 54 ++-- Project.toml | 10 +- docs/Manifest.toml | 539 ++++++----------------------------- docs/src/index.md | 1 + ext/ClimaCommsCUDAExt.jl | 18 ++ ext/ClimaCommsMPIExt.jl | 274 ++++++++++++++++++ src/context.jl | 53 ++-- src/devices.jl | 137 ++++++--- src/mpi.jl | 259 +---------------- test/Project.toml | 3 - test/basic.jl | 219 -------------- test/runtests.jl | 230 ++++++++++++++- 13 files changed, 781 insertions(+), 1034 deletions(-) delete mode 100644 .buildkite/JuliaProject.toml create mode 100644 ext/ClimaCommsCUDAExt.jl create mode 100644 ext/ClimaCommsMPIExt.jl delete mode 100644 test/basic.jl diff --git a/.buildkite/JuliaProject.toml b/.buildkite/JuliaProject.toml deleted file mode 100644 index a83fee5b..00000000 --- a/.buildkite/JuliaProject.toml +++ /dev/null @@ -1,18 +0,0 @@ -[extras] -CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -HDF5_jll = "0234f1f7-429e-5d53-9886-15a909be8d59" -MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" - -[preferences.CUDA_Runtime_jll] -version = "local" - -[preferences.HDF5_jll] -libhdf5_path = "libhdf5" -libhdf5_hl_path = "libhdf5_hl" - -[preferences.MPIPreferences] -_format = "1.0" -abi = "OpenMPI" -binary = "system" -libmpi = "libmpi" -mpiexec = "srun" diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index b424399d..b3d4b660 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,12 +1,9 @@ agents: - queue: central + queue: new-central slurm_mem: 8G - modules: julia/1.9.3 cuda/12.2 ucx/1.14.1_cuda-12.2 openmpi/4.1.5_cuda-12.2 hdf5/1.12.2-ompi415 nsight-systems/2023.2.1 + modules: climacommon/2024_03_18 env: - JULIA_LOAD_PATH: "${JULIA_LOAD_PATH}:${BUILDKITE_BUILD_CHECKOUT_PATH}/.buildkite" - JULIA_CUDA_USE_BINARYBUILDER: false - JULIA_CUDA_MEMORY_POOL: none OPENBLAS_NUM_THREADS: 1 steps: @@ -14,12 +11,13 @@ steps: key: "initialize" command: - echo "--- Instantiate project" - - "julia --project -e 'using Pkg; Pkg.instantiate(;verbose=true); Pkg.precompile(;strict=true)'" + - "julia --project=test -e 'using Pkg; Pkg.develop(;path=\".\"); Pkg.add(\"CUDA\"); Pkg.add(\"MPI\"); Pkg.instantiate(;verbose=true); Pkg.precompile(;strict=true)'" # force the initialization of the CUDA runtime as it is lazily loaded by default - - "julia --project -e 'using CUDA; CUDA.precompile_runtime()'" - - "julia --project -e 'using Pkg; Pkg.status()'" + - "julia --project=test -e 'using CUDA; CUDA.precompile_runtime()'" + - "julia --project=test -e 'using Pkg; Pkg.status()'" agents: + slurm_gpus: 1 slurm_cpus_per_task: 8 env: JULIA_NUM_PRECOMPILE_TASKS: 8 @@ -29,29 +27,53 @@ steps: - label: ":computer: tests" key: "cpu_tests" command: - - julia --project -e 'using Pkg; Pkg.test()' + - julia --project=test test/runtests.jl + env: + CLIMACOMMS_TEST_DEVICE: CPU + + - label: ":computer: tests MPI" + key: "cpu_tests_mpi" + command: + - srun julia --project=test test/runtests.jl env: CLIMACOMMS_TEST_DEVICE: CPU agents: - slurm_nodes: 1 - slurm_ntasks_per_node: 4 + slurm_ntasks: 2 - label: ":computer: threaded tests" key: "cpu_threaded_tests" command: - - julia --threads 8 --project -e 'using Pkg; Pkg.test()' + - julia --threads 4 --project=test test/runtests.jl env: CLIMACOMMS_TEST_DEVICE: CPU agents: - slurm_cpus_per_task: 8 + slurm_cpus_per_task: 4 + + - label: ":computer: threaded tests MPI" + key: "cpu_threaded_tests_mpi" + command: + - srun julia --threads 4 --project=test test/runtests.jl + env: + CLIMACOMMS_TEST_DEVICE: CPU + agents: + slurm_ntasks: 2 + slurm_cpus_per_task: 4 - label: ":flower_playing_cards: tests" key: "gpu_tests" command: - - julia --project -e 'using Pkg; Pkg.test()' + - julia --project=test test/runtests.jl + env: + CLIMACOMMS_TEST_DEVICE: CUDA + agents: + slurm_gpus_per_task: 1 + + - label: ":flower_playing_cards: tests MPI" + key: "gpu_tests_mpi" + command: + - srun julia --project=test test/runtests.jl env: CLIMACOMMS_TEST_DEVICE: CUDA agents: - slurm_nodes: 1 - slurm_ntasks_per_node: 2 slurm_gpus_per_task: 1 + slurm_ntasks: 2 diff --git a/Project.toml b/Project.toml index 05914227..7f3a1ecc 100644 --- a/Project.toml +++ b/Project.toml @@ -8,13 +8,17 @@ authors = [ "Jake Bolewski ", "Gabriele Bozzola ", ] -version = "0.5.8" +version = "0.5.9" -[deps] +[weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +[extensions] +ClimaCommsCUDAExt = "CUDA" +ClimaCommsMPIExt = "MPI" + [compat] CUDA = "3, 4, 5" MPI = "0.20.18" -julia = "1.8" +julia = "1.9" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 1f6a1426..9740e5ca 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.3" +julia_version = "1.10.2" manifest_format = "2.0" project_hash = "2a3f6f2093cc9e00b435b32de577d1b8ccb4f7ac" @@ -9,34 +9,10 @@ git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9" version = "0.0.1" -[[deps.AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.5.0" - - [deps.AbstractFFTs.extensions] - AbstractFFTsChainRulesCoreExt = "ChainRulesCore" - AbstractFFTsTestExt = "Test" - - [deps.AbstractFFTs.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - [[deps.AbstractTrees]] -git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" +git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177" uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.4.4" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.6.2" -weakdeps = ["StaticArrays"] - - [deps.Adapt.extensions] - AdaptStaticArraysExt = "StaticArrays" +version = "0.4.5" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" @@ -45,124 +21,32 @@ version = "1.1.1" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" -[[deps.Atomix]] -deps = ["UnsafeAtomics"] -git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" -uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -version = "0.1.0" - -[[deps.BFloat16s]] -deps = ["LinearAlgebra", "Printf", "Random", "Test"] -git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66" -uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.4.2" - [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" -[[deps.CEnum]] -git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.4.2" - -[[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "Statistics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "f062a48c26ae027f70c44f48f244862aec47bf99" -uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "5.0.0" - - [deps.CUDA.extensions] - SpecialFunctionsExt = "SpecialFunctions" - - [deps.CUDA.weakdeps] - SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" - -[[deps.CUDA_Driver_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "35a37bb72b35964f2895c12c687ae263b4ac170c" -uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" -version = "0.6.0+3" - -[[deps.CUDA_Runtime_Discovery]] -deps = ["Libdl"] -git-tree-sha1 = "bcc4a23cbbd99c8535a5318455dcf0f2546ec536" -uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" -version = "0.2.2" - -[[deps.CUDA_Runtime_jll]] -deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "bfe5a693a11522d58392f742243f2b50dc27afd6" -uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -version = "0.9.2+0" - [[deps.ClimaComms]] -deps = ["CUDA", "MPI"] path = ".." uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" -version = "0.5.4" - -[[deps.ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "eb7f0f8307f71fac7c606984ea5fb2817275d6e4" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.4" - -[[deps.Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "fc08e5930ee9a4e03f84bfb5211cb54e7769758a" -uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.10" - -[[deps.Compat]] -deps = ["UUIDs"] -git-tree-sha1 = "8a62af3e248a8c4bad6b32cbbe663ae02275e32c" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.10.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+0" - -[[deps.Crayons]] -git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.1.1" - -[[deps.DataAPI]] -git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.15.0" - -[[deps.DataFrames]] -deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" -uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.15" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" +version = "0.6.0" + + [deps.ClimaComms.extensions] + ClimaCommsCUDAExt = "CUDA" + ClimaCommsMPIExt = "MPI" + + [deps.ClimaComms.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "59939d8a997469ee05c4b4944560a820f9ba0d73" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.4" [[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - [[deps.DocStringExtensions]] deps = ["LibGit2"] git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" @@ -170,78 +54,47 @@ uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" version = "0.9.3" [[deps.Documenter]] -deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "Dates", "DocStringExtensions", "Downloads", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "Test", "Unicode"] -git-tree-sha1 = "f667b805e90d643aeb1ca70189827f991a7cc115" +deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "CodecZlib", "Dates", "DocStringExtensions", "Downloads", "Git", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "TOML", "Test", "Unicode"] +git-tree-sha1 = "f15a91e6e3919055efa4f206f942a73fedf5dfe6" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "1.1.0" +version = "1.4.0" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" -[[deps.ExprTools]] -git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" -uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.10" +[[deps.Expat_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "4558ab818dcceaab612d1bb8c19cee87eda2b83c" +uuid = "2e619515-83b5-522b-bb60-26c02a35a201" +version = "2.5.0+0" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" -[[deps.FixedPointNumbers]] -deps = ["Statistics"] -git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.4" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.GPUArrays]] -deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "8ad8f375ae365aa1eb2f42e2565a40b55a4b69a8" -uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "9.0.0" - -[[deps.GPUArraysCore]] -deps = ["Adapt"] -git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0" -uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.5" - -[[deps.GPUCompiler]] -deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "5e4487558477f191c043166f8301dd0b4be4e2b2" -uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.24.5" +[[deps.Git]] +deps = ["Git_jll"] +git-tree-sha1 = "04eff47b1354d702c3a85e8ab23d539bb7d5957e" +uuid = "d7ba0133-e1db-5d97-8f8c-041e4b3a1eb2" +version = "1.3.1" + +[[deps.Git_jll]] +deps = ["Artifacts", "Expat_jll", "JLLWrappers", "LibCURL_jll", "Libdl", "Libiconv_jll", "OpenSSL_jll", "PCRE2_jll", "Zlib_jll"] +git-tree-sha1 = "d18fb8a1f3609361ebda9bf029b60fd0f120c809" +uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb" +version = "2.44.0+2" [[deps.IOCapture]] deps = ["Logging", "Random"] -git-tree-sha1 = "d75853a0bdbfb1ac815478bacd89cd27b550ace6" +git-tree-sha1 = "8b72179abc660bfab5e28472e019392b97d0985c" uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.3" - -[[deps.InlineStrings]] -deps = ["Parsers"] -git-tree-sha1 = "9cc2baf75c6d09f9da536ddf58eb2f29dedaf461" -uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" -version = "1.4.0" +version = "0.2.4" [[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -[[deps.InvertedIndices]] -git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" -uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" -version = "1.3.0" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" @@ -254,223 +107,106 @@ git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" version = "0.21.4" -[[deps.JuliaNVTXCallbacks_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "af433a10f3942e882d3c671aacb203e006a5808f" -uuid = "9c1d0b0a-7046-5b2e-a33f-ea22f176ac7e" -version = "0.2.1+0" - -[[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118" -uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.8" - - [deps.KernelAbstractions.extensions] - EnzymeExt = "EnzymeCore" - - [deps.KernelAbstractions.weakdeps] - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "a9d2ce1d5007b1e8f6c5b89c5a31ff8bd146db5c" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.2.1" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "7ca6850ae880cc99b59b88517545f91a52020afa" -uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.25+0" - -[[deps.LaTeXStrings]] -git-tree-sha1 = "f2355693d6778a178ade15952b7ac47a4ff97996" -uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.0" - [[deps.LazilyInitializedFields]] -git-tree-sha1 = "410fe4739a4b092f2ffe36fcb0dcc3ab12648ce1" +git-tree-sha1 = "8f7f3cabab0fd1800699663533b6d5cb3fc0e612" uuid = "0e77f7df-68c5-4e49-93ce-4cd80f5598bf" -version = "1.2.1" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" +version = "1.2.2" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.3" +version = "0.6.4" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "7.84.0+0" +version = "8.4.0+0" [[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.10.2+0" +version = "1.11.0+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[[deps.Libiconv_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" +uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" +version = "1.17.0+0" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" -[[deps.MPI]] -deps = ["Distributed", "DocStringExtensions", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "PkgVersion", "PrecompileTools", "Requires", "Serialization", "Sockets"] -git-tree-sha1 = "b4d8707e42b693720b54f0b3434abee6dd4d947a" -uuid = "da04e1cc-30fd-572f-bb4f-1f8673147195" -version = "0.20.16" - - [deps.MPI.extensions] - AMDGPUExt = "AMDGPU" - CUDAExt = "CUDA" - - [deps.MPI.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - -[[deps.MPICH_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "8a5b4d2220377d1ece13f49438d71ad20cf1ba83" -uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" -version = "4.1.2+0" - -[[deps.MPIPreferences]] -deps = ["Libdl", "Preferences"] -git-tree-sha1 = "781916a2ebf2841467cda03b6f1af43e23839d85" -uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" -version = "0.1.9" - -[[deps.MPItrampoline_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "6979eccb6a9edbbb62681e158443e79ecc0d056a" -uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" -version = "5.3.1+0" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.11" - [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[deps.MarkdownAST]] deps = ["AbstractTrees", "Markdown"] -git-tree-sha1 = "e8513266815200c0c8f522d6d44ffb5e9b366ae4" +git-tree-sha1 = "465a70f0fc7d443a00dcdc3267a497397b8a3899" uuid = "d0879d2d-cac2-40c8-9cee-1863dc0c7391" -version = "0.1.1" +version = "0.1.2" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+0" - -[[deps.MicrosoftMPI_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "a8027af3d1743b3bfae34e54872359fdebb31422" -uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" -version = "10.1.3+4" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.1.0" +version = "2.28.2+1" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.10.11" - -[[deps.NVTX]] -deps = ["Colors", "JuliaNVTXCallbacks_jll", "Libdl", "NVTX_jll"] -git-tree-sha1 = "8bc9ce4233be3c63f8dcd78ccaf1b63a9c0baa34" -uuid = "5da4648a-3479-48b8-97b9-01cb529c0a1f" -version = "0.3.3" - -[[deps.NVTX_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "ce3269ed42816bf18d500c9f63418d4b0d9f5a3b" -uuid = "e98f9f5b-d649-5603-91fd-7774390e6439" -version = "3.1.0+2" +version = "2023.1.10" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.21+4" - -[[deps.OpenMPI_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "f3080f4212a8ba2ceb10a34b938601b862094314" -uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" -version = "4.1.5+0" +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "3da7367955dcc5c54c1ba4d402ccdc09a1a3e046" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.0.13+1" -[[deps.OrderedCollections]] -git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.2" +[[deps.PCRE2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" +version = "10.42.0+1" [[deps.Parsers]] deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "716e24b21538abc91f6205fd1d8363f39b442851" +git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.7.2" +version = "2.8.1" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.2" - -[[deps.PkgVersion]] -deps = ["Pkg"] -git-tree-sha1 = "f9501cc0430a26bc3d156ae1b5b0c1b47af4d6da" -uuid = "eebad327-c553-4316-9ea0-9fa01ccd7688" -version = "0.3.3" - -[[deps.PooledArrays]] -deps = ["DataAPI", "Future"] -git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" -uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -version = "1.4.3" +version = "1.10.0" [[deps.PrecompileTools]] deps = ["Preferences"] -git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f" +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.0" +version = "1.2.1" [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e" +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.1" - -[[deps.PrettyTables]] -deps = ["Crayons", "LaTeXStrings", "Markdown", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "ee094908d720185ddbdc58dbe0c1cbe35453ec7a" -uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.2.7" +version = "1.4.3" [[deps.Printf]] deps = ["Unicode"] @@ -481,118 +217,30 @@ deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[deps.Random]] -deps = ["SHA", "Serialization"] +deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -[[deps.Random123]] -deps = ["Random", "RandomNumbers"] -git-tree-sha1 = "552f30e847641591ba3f39fd1bed559b9deb0ef3" -uuid = "74087812-796a-5b5d-8853-05524746bad3" -version = "1.6.1" - -[[deps.RandomNumbers]] -deps = ["Random", "Requires"] -git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" -uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" -version = "1.5.3" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - [[deps.RegistryInstances]] deps = ["LazilyInitializedFields", "Pkg", "TOML", "Tar"] git-tree-sha1 = "ffd19052caf598b8653b99404058fce14828be51" uuid = "2792f1a3-b283-48e8-9a74-f99dce5104f3" version = "0.1.0" -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" -[[deps.Scratch]] -deps = ["Dates"] -git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a" -uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.2.0" - -[[deps.SentinelArrays]] -deps = ["Dates", "Random"] -git-tree-sha1 = "04bdff0b09c65ff3e06a05e3eb7b120223da3d39" -uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.4.0" - [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.1.1" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore"] -git-tree-sha1 = "0adf069a2a490c47273727e029371b31d44b72b2" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.6.5" -weakdeps = ["Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysStatisticsExt = "Statistics" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.2" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.9.0" - -[[deps.StringManipulation]] -deps = ["PrecompileTools"] -git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" -uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.4" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "5.10.1+6" - [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" version = "1.0.3" -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "a1f34829d5ac0ef499f6d84428bd6b4c71f02ead" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.11.0" - [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" @@ -602,11 +250,14 @@ version = "1.10.0" deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -[[deps.TimerOutputs]] -deps = ["ExprTools", "Printf"] -git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" -uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.23" +[[deps.TranscodingStreams]] +git-tree-sha1 = "71509f04d045ec714c4748c785a59045c3736349" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.10.7" +weakdeps = ["Random", "Test"] + + [deps.TranscodingStreams.extensions] + TestExt = ["Test", "Random"] [[deps.UUIDs]] deps = ["Random", "SHA"] @@ -615,33 +266,17 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" -[[deps.UnsafeAtomics]] -git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" -uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" -version = "0.2.1" - -[[deps.UnsafeAtomicsLLVM]] -deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" -uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.3" - [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+0" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+0" +version = "1.2.13+1" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.48.0+0" +version = "1.52.0+1" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+0" +version = "17.4.0+2" diff --git a/docs/src/index.md b/docs/src/index.md index 91f90f74..26f69cdd 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -17,6 +17,7 @@ ClimaComms.CPUSingleThreading ClimaComms.CPUMultiThreading ClimaComms.CUDADevice ClimaComms.device +ClimaComms.device_functional ClimaComms.array_type ClimaComms.@threaded ClimaComms.@time diff --git a/ext/ClimaCommsCUDAExt.jl b/ext/ClimaCommsCUDAExt.jl new file mode 100644 index 00000000..3618c62a --- /dev/null +++ b/ext/ClimaCommsCUDAExt.jl @@ -0,0 +1,18 @@ +module ClimaCommsCUDAExt + +import CUDA + +import ClimaComms + +function ClimaComms._assign_device(::ClimaComms.CUDADevice, rank_number) + CUDA.device!(rank_number % CUDA.ndevices()) + return nothing +end + +function ClimaComms.device_functional(::ClimaComms.CUDADevice) + return CUDA.functional() +end + +ClimaComms.array_type(::ClimaComms.CUDADevice) = CUDA.CuArray + +end diff --git a/ext/ClimaCommsMPIExt.jl b/ext/ClimaCommsMPIExt.jl new file mode 100644 index 00000000..d04732b9 --- /dev/null +++ b/ext/ClimaCommsMPIExt.jl @@ -0,0 +1,274 @@ +module ClimaCommsMPIExt + +import MPI +import ClimaComms + +ClimaComms.MPICommsContext(device = ClimaComms.device()) = + ClimaComms.MPICommsContext(device, MPI.COMM_WORLD) + +function ClimaComms.init(ctx::ClimaComms.MPICommsContext) + if !MPI.Initialized() + MPI.Init() + end + # TODO: Generalize this to arbitrary accelerators + if ctx.device isa ClimaComms.CUDADevice + if !MPI.has_cuda() + error( + "MPI implementation is not built with CUDA-aware interface. If your MPI is not OpenMPI, you have to set JULIA_MPI_HAS_CUDA to `true`", + ) + end + # assign GPUs based on local rank + local_comm = MPI.Comm_split_type( + ctx.mpicomm, + MPI.COMM_TYPE_SHARED, + MPI.Comm_rank(ctx.mpicomm), + ) + ClimaComms._assign_device(ctx.device, MPI.Comm_rank(local_comm)) + MPI.free(local_comm) + end + return ClimaComms.mypid(ctx), ClimaComms.nprocs(ctx) +end + +ClimaComms.device(ctx::ClimaComms.MPICommsContext) = ctx.device + +ClimaComms.mypid(ctx::ClimaComms.MPICommsContext) = + MPI.Comm_rank(ctx.mpicomm) + 1 +ClimaComms.iamroot(ctx::ClimaComms.MPICommsContext) = ClimaComms.mypid(ctx) == 1 +ClimaComms.nprocs(ctx::ClimaComms.MPICommsContext) = MPI.Comm_size(ctx.mpicomm) + +ClimaComms.barrier(ctx::ClimaComms.MPICommsContext) = MPI.Barrier(ctx.mpicomm) + +ClimaComms.reduce(ctx::ClimaComms.MPICommsContext, val, op) = + MPI.Reduce(val, op, 0, ctx.mpicomm) + +ClimaComms.reduce!(ctx::ClimaComms.MPICommsContext, sendbuf, recvbuf, op) = + MPI.Reduce!(sendbuf, recvbuf, op, ctx.mpicomm; root = 0) + +ClimaComms.reduce!(ctx::ClimaComms.MPICommsContext, sendrecvbuf, op) = + MPI.Reduce!(sendrecvbuf, op, ctx.mpicomm; root = 0) + +ClimaComms.allreduce(ctx::ClimaComms.MPICommsContext, sendbuf, op) = + MPI.Allreduce(sendbuf, op, ctx.mpicomm) + +ClimaComms.allreduce!(ctx::ClimaComms.MPICommsContext, sendbuf, recvbuf, op) = + MPI.Allreduce!(sendbuf, recvbuf, op, ctx.mpicomm) + +ClimaComms.allreduce!(ctx::ClimaComms.MPICommsContext, sendrecvbuf, op) = + MPI.Allreduce!(sendrecvbuf, op, ctx.mpicomm) + +ClimaComms.bcast(ctx::ClimaComms.MPICommsContext, object) = + MPI.bcast(object, ctx.mpicomm; root = 0) + +function ClimaComms.gather(ctx::ClimaComms.MPICommsContext, array) + dims = size(array) + lengths = MPI.Gather(dims[end], 0, ctx.mpicomm) + if ClimaComms.iamroot(ctx) + dimsout = (dims[1:(end - 1)]..., sum(lengths)) + arrayout = similar(array, dimsout) + recvbuf = MPI.VBuffer(arrayout, lengths .* prod(dims[1:(end - 1)])) + else + recvbuf = nothing + end + MPI.Gatherv!(array, recvbuf, 0, ctx.mpicomm) +end + +ClimaComms.abort(ctx::ClimaComms.MPICommsContext, status::Int) = + MPI.Abort(ctx.mpicomm, status) + + +# We could probably do something fancier here? +# Would need to be careful as there is no guarantee that all ranks will call +# finalizers at the same time. +const TAG = Ref(Cint(0)) +function newtag(ctx::ClimaComms.MPICommsContext) + TAG[] = tag = mod(TAG[] + 1, 32767) # TODO: this should query MPI_TAG_UB attribute (https://github.com/JuliaParallel/MPI.jl/pull/551) + if tag == 0 + @warn("MPICommsMPI: tag overflow") + end + return tag +end + +""" + MPISendRecvGraphContext + +A simple ghost buffer implementation using MPI `Isend`/`Irecv` operations. +""" +mutable struct MPISendRecvGraphContext <: ClimaComms.AbstractGraphContext + ctx::ClimaComms.MPICommsContext + tag::Cint + send_bufs::Vector{MPI.Buffer} + send_ranks::Vector{Cint} + send_reqs::MPI.UnsafeMultiRequest + recv_bufs::Vector{MPI.Buffer} + recv_ranks::Vector{Cint} + recv_reqs::MPI.UnsafeMultiRequest +end + +""" + MPIPersistentSendRecvGraphContext + +A simple ghost buffer implementation using MPI persistent send/receive operations. +""" +struct MPIPersistentSendRecvGraphContext <: ClimaComms.AbstractGraphContext + ctx::ClimaComms.MPICommsContext + tag::Cint + send_bufs::Vector{MPI.Buffer} + send_ranks::Vector{Cint} + send_reqs::MPI.UnsafeMultiRequest + recv_bufs::Vector{MPI.Buffer} + recv_ranks::Vector{Cint} + recv_reqs::MPI.UnsafeMultiRequest +end + +function graph_context( + ctx::ClimaComms.MPICommsContext, + send_array, + send_lengths, + send_pids, + recv_array, + recv_lengths, + recv_pids, + ::Type{GCT}, +) where { + GCT <: Union{MPISendRecvGraphContext, MPIPersistentSendRecvGraphContext}, +} + @assert length(send_pids) == length(send_lengths) + @assert length(recv_pids) == length(recv_lengths) + + tag = newtag(ctx) + + send_bufs = MPI.Buffer[] + total_len = 0 + for len in send_lengths + buf = MPI.Buffer(view(send_array, (total_len + 1):(total_len + len))) + push!(send_bufs, buf) + total_len += len + end + send_ranks = Cint[pid - 1 for pid in send_pids] + send_reqs = MPI.UnsafeMultiRequest(length(send_ranks)) + + recv_bufs = MPI.Buffer[] + total_len = 0 + for len in recv_lengths + buf = MPI.Buffer(view(recv_array, (total_len + 1):(total_len + len))) + push!(recv_bufs, buf) + total_len += len + end + recv_ranks = Cint[pid - 1 for pid in recv_pids] + recv_reqs = MPI.UnsafeMultiRequest(length(recv_ranks)) + args = ( + ctx, + tag, + send_bufs, + send_ranks, + send_reqs, + recv_bufs, + recv_ranks, + recv_reqs, + ) + if GCT == MPIPersistentSendRecvGraphContext + # Allocate a persistent receive request + for n in 1:length(recv_bufs) + MPI.Recv_init( + recv_bufs[n], + ctx.mpicomm, + recv_reqs[n]; + source = recv_ranks[n], + tag = tag, + ) + end + # Allocate a persistent send request + for n in 1:length(send_bufs) + MPI.Send_init( + send_bufs[n], + ctx.mpicomm, + send_reqs[n]; + dest = send_ranks[n], + tag = tag, + ) + end + MPIPersistentSendRecvGraphContext(args...) + else + MPISendRecvGraphContext(args...) + end +end + +ClimaComms.graph_context( + ctx::ClimaComms.MPICommsContext, + send_array, + send_lengths, + send_pids, + recv_array, + recv_lengths, + recv_pids; + persistent::Bool = true, +) = graph_context( + ctx, + send_array, + send_lengths, + send_pids, + recv_array, + recv_lengths, + recv_pids, + persistent ? MPIPersistentSendRecvGraphContext : MPISendRecvGraphContext, +) + +function ClimaComms.start( + ghost::MPISendRecvGraphContext; + dependencies = nothing, +) + if !all(MPI.isnull, ghost.recv_reqs) + error("Must finish() before next start()") + end + # post receives + for n in 1:length(ghost.recv_bufs) + MPI.Irecv!( + ghost.recv_bufs[n], + ghost.recv_ranks[n], + ghost.tag, + ghost.ctx.mpicomm, + ghost.recv_reqs[n], + ) + end + # post sends + for n in 1:length(ghost.send_bufs) + MPI.Isend( + ghost.send_bufs[n], + ghost.send_ranks[n], + ghost.tag, + ghost.ctx.mpicomm, + ghost.send_reqs[n], + ) + end +end + +function ClimaComms.start( + ghost::MPIPersistentSendRecvGraphContext; + dependencies = nothing, +) + MPI.Startall(ghost.recv_reqs) # post receives + MPI.Startall(ghost.send_reqs) # post sends +end + +function ClimaComms.progress( + ghost::Union{MPISendRecvGraphContext, MPIPersistentSendRecvGraphContext}, +) + if isdefined(MPI, :MPI_ANY_SOURCE) # < v0.20 + MPI.Iprobe(MPI.MPI_ANY_SOURCE, ghost.tag, ghost.ctx.mpicomm) + else # >= v0.20 + MPI.Iprobe(MPI.ANY_SOURCE, ghost.tag, ghost.ctx.mpicomm) + end +end + +function ClimaComms.finish( + ghost::Union{MPISendRecvGraphContext, MPIPersistentSendRecvGraphContext}; + dependencies = nothing, +) + # wait on previous receives + MPI.Waitall(ghost.recv_reqs) + # ensure that sends have completed + # TODO: these could be moved to start()? but we would need to add a finalizer to make sure they complete. + MPI.Waitall(ghost.send_reqs) +end + +end diff --git a/src/context.jl b/src/context.jl index b06326fc..6ff13322 100644 --- a/src/context.jl +++ b/src/context.jl @@ -1,3 +1,16 @@ +import ..ClimaComms + +""" + ClimaComms.mpi_ext_available() + +Returns true when the `ClimaComms` `ClimaCommsMPIExt` extension was loaded. + +To load `ClimaCommsMPIExt`, just load `ClimaComms` and `MPI`. +""" +function mpi_ext_available() + return !isnothing(Base.get_extension(ClimaComms, :ClimaCommsMPIExt)) +end + """ ClimaComms.context(device=device()) @@ -11,24 +24,28 @@ Behavior can be overridden by setting the `CLIMACOMMS_CONTEXT` environment varia to either `MPI` or `SINGLETON`. """ function context(device = device()) - name = get(ENV, "CLIMACOMMS_CONTEXT", nothing) - if !isnothing(name) - if name == "MPI" - return MPICommsContext() - elseif name == "SINGLETON" - return SingletonCommsContext() + if !(mpi_ext_available()) + return SingletonCommsContext(device) + else + name = get(ENV, "CLIMACOMMS_CONTEXT", nothing) + if !isnothing(name) + if name == "MPI" + return MPICommsContext() + elseif name == "SINGLETON" + return SingletonCommsContext() + else + error("Invalid context: $name") + end + end + # detect common environment variables used by MPI launchers + # PMI_RANK appears to be used by MPICH and srun + # OMPI_COMM_WORLD_RANK appears to be used by OpenMPI + if haskey(ENV, "PMI_RANK") || haskey(ENV, "OMPI_COMM_WORLD_RANK") + return MPICommsContext(device) else - error("Invalid context: $name") + return SingletonCommsContext(device) end end - # detect common environment variables used by MPI launchers - # PMI_RANK appears to be used by MPICH and srun - # OMPI_COMM_WORLD_RANK appears to be used by OpenMPI - if haskey(ENV, "PMI_RANK") || haskey(ENV, "OMPI_COMM_WORLD_RANK") - return MPICommsContext(device) - else - return SingletonCommsContext(device) - end end """ @@ -149,9 +166,9 @@ A context for communicating between processes in a graph. abstract type AbstractGraphContext end """ - graph_context(context::AbstractCommsContext, - sendarray, sendlengths, sendpids, - recvarray, recvlengths, recvpids) + graph_context(context::AbstractCommsContext, + sendarray, sendlengths, sendpids, + recvarray, recvlengths, recvpids) Construct a communication context for exchanging neighbor data via a graph. diff --git a/src/devices.jl b/src/devices.jl index d05efbfe..2cd77480 100644 --- a/src/devices.jl +++ b/src/devices.jl @@ -1,6 +1,4 @@ -# we previously used CUDA as a variable -import CUDA - +import ..ClimaComms """ AbstractDevice @@ -38,6 +36,27 @@ Use NVIDIA GPU accelarator """ struct CUDADevice <: AbstractDevice end +""" + ClimaComms.cuda_ext_available() + +Returns true when the `ClimaComms` `ClimaCommsCUDAExt` extension was loaded. + +To load `ClimaCommsCUDAExt`, just load `ClimaComms` and `CUDA`. +""" +function cuda_ext_available() + return !isnothing(Base.get_extension(ClimaComms, :ClimaCommsCUDAExt)) +end + +""" + ClimaComms.device_functional(device) + +Return true when the `device` is correctly set up. +""" +function device_functional end + +device_functional(::CPUSingleThreaded) = true +device_functional(::CPUMultiThreaded) = true + """ ClimaComms.device() @@ -60,12 +79,13 @@ function device() elseif env_var == "CPUMultiThreaded" return CPUMultiThreaded() elseif env_var == "CUDA" + cuda_ext_available() || error("CUDA was not loaded") return CUDADevice() else error("Invalid CLIMACOMMS_DEVICE: $env_var") end end - if CUDA.functional() + if cuda_ext_available() && device_functional(CUDADevice()) return CUDADevice() else return Threads.nthreads() == 1 ? CPUSingleThreaded() : @@ -79,8 +99,13 @@ end The base array type used by the specified device (currently `Array` or `CuArray`). """ array_type(::AbstractCPUDevice) = Array -array_type(::CUDADevice) = CUDA.CuArray +""" +Internal function that can be used to assign a device to a process. + +Currently used to assign CUDADevices to MPI ranks. +""" +_assign_device(device, id) = nothing """ @threaded device for ... end @@ -128,14 +153,22 @@ CUDA.@time expr for CUDA devices. """ macro time(device, expr) - return quote - if $(esc(device)) isa CUDADevice - CUDA.@time $(esc(expr)) - else - @assert $(esc(device)) isa AbstractDevice - Base.@time $(esc(expr)) - end - end + return esc( + quote + if $device isa $CUDADevice + @static if isnothing( + $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt), + ) + error("CUDA not loaded") + else + $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@time $expr + end + else + @assert $device isa $AbstractDevice + $Base.@time $(expr) + end + end, + ) end """ @@ -154,14 +187,22 @@ CUDA.@elapsed expr for CUDA devices. """ macro elapsed(device, expr) - return quote - if $(esc(device)) isa CUDADevice - CUDA.@elapsed $(esc(expr)) - else - @assert $(esc(device)) isa AbstractDevice - Base.@elapsed $(esc(expr)) - end - end + return esc( + quote + if $device isa $CUDADevice + @static if isnothing( + $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt), + ) + error("CUDA not loaded") + else + $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@elapsed $expr + end + else + @assert $device isa $AbstractDevice + $Base.@elapsed $(expr) + end + end, + ) end """ @@ -200,18 +241,26 @@ to synchronize), then you may want to simply use [`@cuda_sync`](@ref). """ macro sync(device, expr) # https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207 - return esc(quote - if $(device) isa $CUDADevice - $CUDA.@sync begin - $(expr) - end - else - @assert $(device) isa $AbstractDevice - $Base.@sync begin - $(expr) + return esc( + quote + if $device isa $CUDADevice + @static if isnothing( + $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt), + ) + error("CUDA not loaded") + else + $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@sync begin + $(expr) + end + end + else + @assert $device isa $AbstractDevice + $Base.@sync begin + $(expr) + end end - end - end) + end, + ) end """ @@ -231,14 +280,22 @@ for CUDA devices. """ macro cuda_sync(device, expr) # https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207 - return esc(quote - if $(device) isa $CUDADevice - $CUDA.@sync begin + return esc( + quote + if $device isa $CUDADevice + @static if isnothing( + $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt), + ) + error("CUDA not loaded") + else + $Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@sync begin + $(expr) + end + end + else + @assert $device isa $AbstractDevice $(expr) end - else - @assert $(device) isa $AbstractDevice - $(expr) - end - end) + end, + ) end diff --git a/src/mpi.jl b/src/mpi.jl index 3fcaadfd..524cc2cc 100644 --- a/src/mpi.jl +++ b/src/mpi.jl @@ -1,5 +1,3 @@ -using MPI - """ MPICommsContext() MPICommsContext(device) @@ -8,262 +6,9 @@ using MPI A MPI communications context, used for distributed runs. [`AbstractCPUDevice`](@ref) and [`CUDADevice`](@ref) device options are currently supported. """ -struct MPICommsContext{D <: AbstractDevice, C <: MPI.Comm} <: - AbstractCommsContext +struct MPICommsContext{D <: AbstractDevice, C} <: AbstractCommsContext device::D mpicomm::C end -MPICommsContext(device = device()) = MPICommsContext(device, MPI.COMM_WORLD) - -device(ctx::MPICommsContext) = ctx.device - - -function init(ctx::MPICommsContext) - if !MPI.Initialized() - MPI.Init() - end - if ctx.device isa CUDADevice - if !MPI.has_cuda() - error("MPI implementation is not built with CUDA-aware interface") - end - # assign GPUs based on local rank - local_comm = MPI.Comm_split_type( - ctx.mpicomm, - MPI.COMM_TYPE_SHARED, - MPI.Comm_rank(ctx.mpicomm), - ) - CUDA.device!(MPI.Comm_rank(local_comm) % CUDA.ndevices()) - MPI.free(local_comm) - end - return mypid(ctx), nprocs(ctx) -end - -mypid(ctx::MPICommsContext) = MPI.Comm_rank(ctx.mpicomm) + 1 -iamroot(ctx::MPICommsContext) = mypid(ctx) == 1 -nprocs(ctx::MPICommsContext) = MPI.Comm_size(ctx.mpicomm) - -barrier(ctx::MPICommsContext) = MPI.Barrier(ctx.mpicomm) - -reduce(ctx::MPICommsContext, val, op) = MPI.Reduce(val, op, 0, ctx.mpicomm) - -reduce!(ctx::MPICommsContext, sendbuf, recvbuf, op) = - MPI.Reduce!(sendbuf, recvbuf, op, ctx.mpicomm; root = 0) - -reduce!(ctx::MPICommsContext, sendrecvbuf, op) = - MPI.Reduce!(sendrecvbuf, op, ctx.mpicomm; root = 0) - -allreduce(ctx::MPICommsContext, sendbuf, op) = - MPI.Allreduce(sendbuf, op, ctx.mpicomm) - -allreduce!(ctx::MPICommsContext, sendbuf, recvbuf, op) = - MPI.Allreduce!(sendbuf, recvbuf, op, ctx.mpicomm) - -allreduce!(ctx::MPICommsContext, sendrecvbuf, op) = - MPI.Allreduce!(sendrecvbuf, op, ctx.mpicomm) - -bcast(ctx::MPICommsContext, object) = MPI.bcast(object, ctx.mpicomm; root = 0) - -function gather(ctx::MPICommsContext, array) - dims = size(array) - lengths = MPI.Gather(dims[end], 0, ctx.mpicomm) - if iamroot(ctx) - dimsout = (dims[1:(end - 1)]..., sum(lengths)) - arrayout = similar(array, dimsout) - recvbuf = MPI.VBuffer(arrayout, lengths .* prod(dims[1:(end - 1)])) - else - recvbuf = nothing - end - MPI.Gatherv!(array, recvbuf, 0, ctx.mpicomm) -end - -abort(ctx::MPICommsContext, status::Int) = MPI.Abort(ctx.mpicomm, status) - - -# We could probably do something fancier here? -# Would need to be careful as there is no guarantee that all ranks will call -# finalizers at the same time. -const TAG = Ref(Cint(0)) -function newtag(ctx::MPICommsContext) - TAG[] = tag = mod(TAG[] + 1, 32767) # TODO: this should query MPI_TAG_UB attribute (https://github.com/JuliaParallel/MPI.jl/pull/551) - if tag == 0 - @warn("MPICommsMPI: tag overflow") - end - return tag -end - -""" - MPISendRecvGraphContext -A simple ghost buffer implementation using MPI `Isend`/`Irecv` operations. -""" -mutable struct MPISendRecvGraphContext <: AbstractGraphContext - ctx::MPICommsContext - tag::Cint - send_bufs::Vector{MPI.Buffer} - send_ranks::Vector{Cint} - send_reqs::MPI.UnsafeMultiRequest - recv_bufs::Vector{MPI.Buffer} - recv_ranks::Vector{Cint} - recv_reqs::MPI.UnsafeMultiRequest -end - -""" - MPIPersistentSendRecvGraphContext - -A simple ghost buffer implementation using MPI persistent send/receive operations. -""" -struct MPIPersistentSendRecvGraphContext <: AbstractGraphContext - ctx::MPICommsContext - tag::Cint - send_bufs::Vector{MPI.Buffer} - send_ranks::Vector{Cint} - send_reqs::MPI.UnsafeMultiRequest - recv_bufs::Vector{MPI.Buffer} - recv_ranks::Vector{Cint} - recv_reqs::MPI.UnsafeMultiRequest -end - -function graph_context( - ctx::MPICommsContext, - send_array, - send_lengths, - send_pids, - recv_array, - recv_lengths, - recv_pids, - ::Type{GCT}, -) where { - GCT <: Union{MPISendRecvGraphContext, MPIPersistentSendRecvGraphContext}, -} - @assert length(send_pids) == length(send_lengths) - @assert length(recv_pids) == length(recv_lengths) - - tag = newtag(ctx) - - send_bufs = MPI.Buffer[] - total_len = 0 - for len in send_lengths - buf = MPI.Buffer(view(send_array, (total_len + 1):(total_len + len))) - push!(send_bufs, buf) - total_len += len - end - send_ranks = Cint[pid - 1 for pid in send_pids] - send_reqs = MPI.UnsafeMultiRequest(length(send_ranks)) - - recv_bufs = MPI.Buffer[] - total_len = 0 - for len in recv_lengths - buf = MPI.Buffer(view(recv_array, (total_len + 1):(total_len + len))) - push!(recv_bufs, buf) - total_len += len - end - recv_ranks = Cint[pid - 1 for pid in recv_pids] - recv_reqs = MPI.UnsafeMultiRequest(length(recv_ranks)) - args = ( - ctx, - tag, - send_bufs, - send_ranks, - send_reqs, - recv_bufs, - recv_ranks, - recv_reqs, - ) - if GCT == MPIPersistentSendRecvGraphContext - # Allocate a persistent receive request - for n in 1:length(recv_bufs) - MPI.Recv_init( - recv_bufs[n], - ctx.mpicomm, - recv_reqs[n]; - source = recv_ranks[n], - tag = tag, - ) - end - # Allocate a persistent send request - for n in 1:length(send_bufs) - MPI.Send_init( - send_bufs[n], - ctx.mpicomm, - send_reqs[n]; - dest = send_ranks[n], - tag = tag, - ) - end - MPIPersistentSendRecvGraphContext(args...) - else - MPISendRecvGraphContext(args...) - end -end - -graph_context( - ctx::MPICommsContext, - send_array, - send_lengths, - send_pids, - recv_array, - recv_lengths, - recv_pids; - persistent::Bool = true, -) = graph_context( - ctx, - send_array, - send_lengths, - send_pids, - recv_array, - recv_lengths, - recv_pids, - persistent ? MPIPersistentSendRecvGraphContext : MPISendRecvGraphContext, -) - -function start(ghost::MPISendRecvGraphContext; dependencies = nothing) - if !all(MPI.isnull, ghost.recv_reqs) - error("Must finish() before next start()") - end - # post receives - for n in 1:length(ghost.recv_bufs) - MPI.Irecv!( - ghost.recv_bufs[n], - ghost.recv_ranks[n], - ghost.tag, - ghost.ctx.mpicomm, - ghost.recv_reqs[n], - ) - end - # post sends - for n in 1:length(ghost.send_bufs) - MPI.Isend( - ghost.send_bufs[n], - ghost.send_ranks[n], - ghost.tag, - ghost.ctx.mpicomm, - ghost.send_reqs[n], - ) - end -end - -function start(ghost::MPIPersistentSendRecvGraphContext; dependencies = nothing) - MPI.Startall(ghost.recv_reqs) # post receives - MPI.Startall(ghost.send_reqs) # post sends -end - -function progress( - ghost::Union{MPISendRecvGraphContext, MPIPersistentSendRecvGraphContext}, -) - if isdefined(MPI, :MPI_ANY_SOURCE) # < v0.20 - MPI.Iprobe(MPI.MPI_ANY_SOURCE, ghost.tag, ghost.ctx.mpicomm) - else # >= v0.20 - MPI.Iprobe(MPI.ANY_SOURCE, ghost.tag, ghost.ctx.mpicomm) - end -end - -function finish( - ghost::Union{MPISendRecvGraphContext, MPIPersistentSendRecvGraphContext}; - dependencies = nothing, -) - # wait on previous receives - MPI.Waitall(ghost.recv_reqs) - # ensure that sends have completed - # TODO: these could be moved to start()? but we would need to add a finalizer to make sure they complete. - MPI.Waitall(ghost.send_reqs) -end +function MPICommsContext end diff --git a/test/Project.toml b/test/Project.toml index 309c8759..6dec72bc 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,8 +1,5 @@ [deps] -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" -MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/basic.jl b/test/basic.jl deleted file mode 100644 index 4e448546..00000000 --- a/test/basic.jl +++ /dev/null @@ -1,219 +0,0 @@ - -using Test -using ClimaComms - -context = ClimaComms.context() -pid, nprocs = ClimaComms.init(context) -device = ClimaComms.device(context) -AT = ClimaComms.array_type(device) - -if ClimaComms.iamroot(context) - @info "Running test" context device AT -end - -if haskey(ENV, "CLIMACOMMS_TEST_DEVICE") - if ENV["CLIMACOMMS_TEST_DEVICE"] == "CPU" - @test device isa ClimaComms.AbstractCPUDevice - elseif ENV["CLIMACOMMS_TEST_DEVICE"] == "CPUSingleThreaded" - @test device isa ClimaComms.CPUSingleThreaded - elseif ENV["CLIMACOMMS_TEST_DEVICE"] == "CPUMultiThreaded" - @test device isa ClimaComms.CPUMultiThreaded - elseif ENV["CLIMACOMMS_TEST_DEVICE"] == "CUDA" - @test device isa ClimaComms.CUDADevice - end -end - -using SafeTestsets -@safetestset "macro hygiene" begin - include("hygiene.jl") -end - -if context isa ClimaComms.MPICommsContext - graph_opt_list = [(; persistent = true), (; persistent = false)] -else - graph_opt_list = [()] -end -@testset "tree test $graph_opt" for graph_opt in graph_opt_list - for FT in (Float32, Float64) - # every process communicates with the root - if ClimaComms.iamroot(context) - # send 2*n items to the nth pid, receive 3*n - sendpids = collect(2:nprocs) - sendlengths = [2 * dest for dest in sendpids] - sendarray = AT(fill(zero(FT), (2, sum(sendpids)))) - recvpids = collect(2:nprocs) - recvlengths = [3 * dest for dest in recvpids] - recvarray = AT(fill(zero(FT), (3, sum(recvpids)))) - else - # send 3*pid items to the 1st pid, receive 2*pid - sendpids = [1] - sendlengths = [3 * pid] - sendarray = AT(fill(zero(FT), (3, pid))) - recvpids = [1] - recvlengths = [2 * pid] - recvarray = AT(fill(zero(FT), (2, pid))) - end - graph_context = ClimaComms.graph_context( - context, - sendarray, - sendlengths, - sendpids, - recvarray, - recvlengths, - recvpids; - graph_opt..., - ) - - # 1) fill buffers with pid - fill!(sendarray, FT(pid)) - - ClimaComms.start(graph_context) - ClimaComms.progress(graph_context) - ClimaComms.finish(graph_context) - - if ClimaComms.iamroot(context) - offset = 0 - for s in 2:nprocs - @test all( - ==(FT(s)), - view(recvarray, :, (offset + 1):(offset + s)), - ) - offset += s - end - else - @test all(==(FT(1)), recvarray) - end - - # 2) send everything back - if ClimaComms.iamroot(context) - sendarray .= view(recvarray, 1:2, :) - else - sendarray .= FT(1) - end - - ClimaComms.start(graph_context) - ClimaComms.progress(graph_context) - ClimaComms.finish(graph_context) - - @test all(==(FT(pid)), recvarray) - end -end - -@testset "linear test $graph_opt" for graph_opt in graph_opt_list - for FT in (Float32, Float64) - # send 2 values up - if pid < nprocs - sendpids = Int[pid + 1] - sendlengths = Int[2] - sendarray = AT(fill(zero(FT), (2,))) - else - sendpids = Int[] - sendlengths = Int[] - sendarray = AT(fill(zero(FT), (0,))) - end - if pid > 1 - recvpids = Int[pid - 1] - recvlengths = Int[2] - recvarray = AT(fill(zero(FT), (2,))) - else - recvpids = Int[] - recvlengths = Int[] - recvarray = AT(fill(zero(FT), (0,))) - end - graph_context = ClimaComms.graph_context( - context, - sendarray, - sendlengths, - sendpids, - recvarray, - recvlengths, - recvpids; - graph_opt..., - ) - - # 1) send pid - if pid < nprocs - sendarray .= FT(pid) - end - ClimaComms.start(graph_context) - ClimaComms.progress(graph_context) - ClimaComms.finish(graph_context) - - if pid > 1 - @test all(==(FT(pid - 1)), recvarray) - end - - # 2) send next - if 1 < pid < nprocs - sendarray .= recvarray - end - - ClimaComms.start(graph_context) - ClimaComms.progress(graph_context) - ClimaComms.finish(graph_context) - - if pid > 2 - @test all(==(FT(pid - 2)), recvarray) - end - end -end - -@testset "gather" begin - for FT in (Float32, Float64) - local_array = AT(fill(FT(pid), (3, pid))) - gathered = ClimaComms.gather(context, local_array) - if ClimaComms.iamroot(context) - @test gathered isa AT - @test gathered == - AT(reduce(hcat, [fill(FT(i), (3, i)) for i in 1:nprocs])) - else - @test isnothing(gathered) - end - end -end - -@testset "reduce/reduce!/allreduce" begin - for FT in (Float32, Float64) - pidsum = div(nprocs * (nprocs + 1), 2) - - sendrecvbuf = AT(fill(FT(pid), 3)) - ClimaComms.allreduce!(context, sendrecvbuf, +) - @test sendrecvbuf == AT(fill(FT(pidsum), 3)) - - sendrecvbuf = AT(fill(FT(pid), 3)) - ClimaComms.reduce!(context, sendrecvbuf, +) - if ClimaComms.iamroot(context) - @test sendrecvbuf == AT(fill(FT(pidsum), 3)) - end - - sendbuf = AT(fill(FT(pid), 2)) - recvbuf = AT(zeros(FT, 2)) - ClimaComms.reduce!(context, sendbuf, recvbuf, +) - if ClimaComms.iamroot(context) - @test recvbuf == AT(fill(FT(pidsum), 2)) - end - - sendbuf = AT(fill(FT(pid), 2)) - recvbuf = AT(zeros(FT, 2)) - ClimaComms.allreduce!(context, sendbuf, recvbuf, +) - @test recvbuf == AT(fill(FT(pidsum), 2)) - - recvval = ClimaComms.allreduce(context, FT(pid), +) - @test recvval == FT(pidsum) - - recvval = ClimaComms.reduce(context, FT(pid), +) - if ClimaComms.iamroot(context) - @test recvval == FT(pidsum) - end - end -end - -@testset "bcast" begin - @test ClimaComms.bcast(context, ClimaComms.iamroot(context)) - @test ClimaComms.bcast(context, pid) == 1 - @test ClimaComms.bcast(context, "root pid is $pid") == "root pid is 1" - @test ClimaComms.bcast(context, AT(fill(Float32(pid), 3))) == - AT(fill(Float32(1), 3)) - @test ClimaComms.bcast(context, AT(fill(Float64(pid), 3))) == - AT(fill(Float64(1), 3)) -end diff --git a/test/runtests.jl b/test/runtests.jl index 0950959e..e0a085d8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,13 +1,227 @@ -using CUDA, CUDA_Runtime_jll +using Test +using ClimaComms -include("basic.jl") +if haskey(ENV, "CLIMACOMMS_TEST_DEVICE") && + ENV["CLIMACOMMS_TEST_DEVICE"] == "CUDA" + import CUDA + @test ClimaComms.cuda_ext_available() +end + +import MPI +@test ClimaComms.mpi_ext_available() -using MPI, Test +context = ClimaComms.context() +pid, nprocs = ClimaComms.init(context) +device = ClimaComms.device(context) +AT = ClimaComms.array_type(device) -function runmpi(file; ntasks = 1) - Base.run( - `$(MPI.mpiexec()) -n $ntasks $(Base.julia_cmd()) --startup-file=no --project=$(Base.active_project()) $file`, - ) +if ClimaComms.iamroot(context) + @info "Running test" context device AT end -@test success(runmpi(joinpath(@__DIR__, "basic.jl"), ntasks = 2)) +if haskey(ENV, "CLIMACOMMS_TEST_DEVICE") + if ENV["CLIMACOMMS_TEST_DEVICE"] == "CPU" + @test device isa ClimaComms.AbstractCPUDevice + elseif ENV["CLIMACOMMS_TEST_DEVICE"] == "CPUSingleThreaded" + @test device isa ClimaComms.CPUSingleThreaded + elseif ENV["CLIMACOMMS_TEST_DEVICE"] == "CPUMultiThreaded" + @test device isa ClimaComms.CPUMultiThreaded + elseif ENV["CLIMACOMMS_TEST_DEVICE"] == "CUDA" + @test device isa ClimaComms.CUDADevice + end +end + +using SafeTestsets +@safetestset "macro hygiene" begin + include("hygiene.jl") +end + +if context isa ClimaComms.MPICommsContext + graph_opt_list = [(; persistent = true), (; persistent = false)] +else + graph_opt_list = [()] +end +@testset "tree test $graph_opt" for graph_opt in graph_opt_list + for FT in (Float32, Float64) + # every process communicates with the root + if ClimaComms.iamroot(context) + # send 2*n items to the nth pid, receive 3*n + sendpids = collect(2:nprocs) + sendlengths = [2 * dest for dest in sendpids] + sendarray = AT(fill(zero(FT), (2, sum(sendpids)))) + recvpids = collect(2:nprocs) + recvlengths = [3 * dest for dest in recvpids] + recvarray = AT(fill(zero(FT), (3, sum(recvpids)))) + else + # send 3*pid items to the 1st pid, receive 2*pid + sendpids = [1] + sendlengths = [3 * pid] + sendarray = AT(fill(zero(FT), (3, pid))) + recvpids = [1] + recvlengths = [2 * pid] + recvarray = AT(fill(zero(FT), (2, pid))) + end + graph_context = ClimaComms.graph_context( + context, + sendarray, + sendlengths, + sendpids, + recvarray, + recvlengths, + recvpids; + graph_opt..., + ) + + # 1) fill buffers with pid + fill!(sendarray, FT(pid)) + + ClimaComms.start(graph_context) + ClimaComms.progress(graph_context) + ClimaComms.finish(graph_context) + + if ClimaComms.iamroot(context) + offset = 0 + for s in 2:nprocs + @test all( + ==(FT(s)), + view(recvarray, :, (offset + 1):(offset + s)), + ) + offset += s + end + else + @test all(==(FT(1)), recvarray) + end + + # 2) send everything back + if ClimaComms.iamroot(context) + sendarray .= view(recvarray, 1:2, :) + else + sendarray .= FT(1) + end + + ClimaComms.start(graph_context) + ClimaComms.progress(graph_context) + ClimaComms.finish(graph_context) + + @test all(==(FT(pid)), recvarray) + end +end + +@testset "linear test $graph_opt" for graph_opt in graph_opt_list + for FT in (Float32, Float64) + # send 2 values up + if pid < nprocs + sendpids = Int[pid + 1] + sendlengths = Int[2] + sendarray = AT(fill(zero(FT), (2,))) + else + sendpids = Int[] + sendlengths = Int[] + sendarray = AT(fill(zero(FT), (0,))) + end + if pid > 1 + recvpids = Int[pid - 1] + recvlengths = Int[2] + recvarray = AT(fill(zero(FT), (2,))) + else + recvpids = Int[] + recvlengths = Int[] + recvarray = AT(fill(zero(FT), (0,))) + end + graph_context = ClimaComms.graph_context( + context, + sendarray, + sendlengths, + sendpids, + recvarray, + recvlengths, + recvpids; + graph_opt..., + ) + + # 1) send pid + if pid < nprocs + sendarray .= FT(pid) + end + ClimaComms.start(graph_context) + ClimaComms.progress(graph_context) + ClimaComms.finish(graph_context) + + if pid > 1 + @test all(==(FT(pid - 1)), recvarray) + end + + # 2) send next + if 1 < pid < nprocs + sendarray .= recvarray + end + + ClimaComms.start(graph_context) + ClimaComms.progress(graph_context) + ClimaComms.finish(graph_context) + + if pid > 2 + @test all(==(FT(pid - 2)), recvarray) + end + end +end + +@testset "gather" begin + for FT in (Float32, Float64) + local_array = AT(fill(FT(pid), (3, pid))) + gathered = ClimaComms.gather(context, local_array) + if ClimaComms.iamroot(context) + @test gathered isa AT + @test gathered == + AT(reduce(hcat, [fill(FT(i), (3, i)) for i in 1:nprocs])) + else + @test isnothing(gathered) + end + end +end + +@testset "reduce/reduce!/allreduce" begin + for FT in (Float32, Float64) + pidsum = div(nprocs * (nprocs + 1), 2) + + sendrecvbuf = AT(fill(FT(pid), 3)) + ClimaComms.allreduce!(context, sendrecvbuf, +) + @test sendrecvbuf == AT(fill(FT(pidsum), 3)) + + sendrecvbuf = AT(fill(FT(pid), 3)) + ClimaComms.reduce!(context, sendrecvbuf, +) + if ClimaComms.iamroot(context) + @test sendrecvbuf == AT(fill(FT(pidsum), 3)) + end + + sendbuf = AT(fill(FT(pid), 2)) + recvbuf = AT(zeros(FT, 2)) + ClimaComms.reduce!(context, sendbuf, recvbuf, +) + if ClimaComms.iamroot(context) + @test recvbuf == AT(fill(FT(pidsum), 2)) + end + + sendbuf = AT(fill(FT(pid), 2)) + recvbuf = AT(zeros(FT, 2)) + ClimaComms.allreduce!(context, sendbuf, recvbuf, +) + @test recvbuf == AT(fill(FT(pidsum), 2)) + + recvval = ClimaComms.allreduce(context, FT(pid), +) + @test recvval == FT(pidsum) + + recvval = ClimaComms.reduce(context, FT(pid), +) + if ClimaComms.iamroot(context) + @test recvval == FT(pidsum) + end + end +end + +@testset "bcast" begin + @test ClimaComms.bcast(context, ClimaComms.iamroot(context)) + @test ClimaComms.bcast(context, pid) == 1 + @test ClimaComms.bcast(context, "root pid is $pid") == "root pid is 1" + @test ClimaComms.bcast(context, AT(fill(Float32(pid), 3))) == + AT(fill(Float32(1), 3)) + @test ClimaComms.bcast(context, AT(fill(Float64(pid), 3))) == + AT(fill(Float64(1), 3)) +end