From 8ad24969671b1e7f27693fd3eaf98f5b99a1f9b9 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Wed, 5 Jun 2024 16:11:01 -0400 Subject: [PATCH] Define allowscalar --- NEWS.md | 7 +++++++ docs/src/index.md | 1 + ext/ClimaCommsCUDAExt.jl | 2 ++ src/devices.jl | 24 ++++++++++++++++++++++++ test/runtests.jl | 10 ++++++++++ 5 files changed, 44 insertions(+) diff --git a/NEWS.md b/NEWS.md index 2d90711a..8e9eda21 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,14 @@ ClimaComms.jl Release Notes ======================== +v0.6.2 +------- + +- We added a device-agnostic `allowscalar(f, ::AbstractDevice, args...; kwargs...)` to further assist in making CUDA an extension. + v0.6.1 +------- + - Macros have been refactored to hopefully fix some code loading issues. v0.6.0 diff --git a/docs/src/index.md b/docs/src/index.md index 564031a8..f5996e74 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -27,6 +27,7 @@ ClimaComms.CUDADevice ClimaComms.device ClimaComms.device_functional ClimaComms.array_type +ClimaComms.allowscalar ClimaComms.@threaded ClimaComms.@time ClimaComms.@elapsed diff --git a/ext/ClimaCommsCUDAExt.jl b/ext/ClimaCommsCUDAExt.jl index 3c6cf1d0..fff76b94 100644 --- a/ext/ClimaCommsCUDAExt.jl +++ b/ext/ClimaCommsCUDAExt.jl @@ -14,6 +14,8 @@ function ClimaComms.device_functional(::ClimaComms.CUDADevice) end ClimaComms.array_type(::ClimaComms.CUDADevice) = CUDA.CuArray +ClimaComms.allowscalar(f, ::ClimaComms.CUDADevice, args...; kwargs...) = + CUDA.@allowscalar f(args...; kwargs...) # Extending ClimaComms methods that operate on expressions (cannot use dispatch here) ClimaComms.cuda_sync(expr) = CUDA.@sync expr diff --git a/src/devices.jl b/src/devices.jl index 0ac5e882..bd77d089 100644 --- a/src/devices.jl +++ b/src/devices.jl @@ -265,3 +265,27 @@ macro cuda_sync(device, expr) end end) end + +""" + allowscalar(f, ::AbstractDevice, args...; kwargs...) + +Device-flexible version of `CUDA.@allowscalar`. + +Lowers to +```julia +f(args...) +``` +for CPU devices and +```julia +CUDA.@allowscalar f(args...) +``` +for CUDA devices. + +This is usefully written with closures via +```julia +allowscalar(device) do + f() +end +``` +""" +allowscalar(f, ::AbstractDevice, args...; kwargs...) = f(args...; kwargs...) diff --git a/test/runtests.jl b/test/runtests.jl index 5574fdd3..2f4b6c60 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -217,3 +217,13 @@ end @test ClimaComms.bcast(context, AT(fill(Float64(pid), 3))) == AT(fill(Float64(1), 3)) end + +@testset "allowscalar" begin + a = AT(rand(3)) + local x + ClimaComms.allowscalar(device) do + x = a[1] + end + device isa ClimaComms.CUDADevice && @test_throws ErrorException a[1] + @test x == Array(a)[1] +end