From ea983ecd56aeeefb1db2a28d0516592e94823130 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Thu, 13 Jul 2023 10:37:22 -0700 Subject: [PATCH] Improve hygiene test Fix macro hygiene Bump patch version --- Project.toml | 2 +- src/devices.jl | 16 ++++++++++------ test/hygiene.jl | 25 ++++++++++++++++--------- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/Project.toml b/Project.toml index 10ca2c49..985a1512 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ClimaComms" uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" authors = ["CliMA Contributors "] -version = "0.5.2" +version = "0.5.3" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/src/devices.jl b/src/devices.jl index 4e0dcae7..1133b483 100644 --- a/src/devices.jl +++ b/src/devices.jl @@ -97,13 +97,17 @@ that this is statically inferred. - https://discourse.julialang.org/t/threads-threads-with-one-thread-how-to-remove-the-overhead/58435 - https://discourse.julialang.org/t/overhead-of-threads-threads/53964 """ -macro threaded(device, expr) - return quote +macro threaded(device, loop) + quote if $(esc(device)) isa CPUMultiThreaded - Threads.@threads $(expr) + Threads.@threads $(Expr( + loop.head, + Expr(loop.args[1].head, esc.(loop.args[1].args)...), + esc(loop.args[2]), + )) else @assert $(esc(device)) isa AbstractDevice - $(esc(expr)) + $(esc(loop)) end end end @@ -152,10 +156,10 @@ for CUDA devices. macro elapsed(device, expr) return quote if $(esc(device)) isa CUDADevice - CUDA.@elapsed $(expr) + CUDA.@elapsed $(esc(expr)) else @assert $(esc(device)) isa AbstractDevice - Base.@elapsed $(expr) + Base.@elapsed $(esc(expr)) end end end diff --git a/test/hygiene.jl b/test/hygiene.jl index 216de5ea..6465e2b2 100644 --- a/test/hygiene.jl +++ b/test/hygiene.jl @@ -1,13 +1,20 @@ import ClimaComms as CC -dev = CC.device() -CC.@threaded dev for i in 1:2 - 1 -end +function test_macro_hyhiene(dev) + n = 3 # tests that we can reach variables defined in scope + CC.@threaded dev for i in 1:n + if Threads.nthreads() > 1 + @show Threads.threadid() + end + end -CC.@time dev for i in 1:2 - sin.(rand(10)) -end + CC.@time dev for i in 1:n + sin.(rand(10)) + end -CC.@elapsed dev for i in 1:2 - sin.(rand(10)) + CC.@elapsed dev for i in 1:n + sin.(rand(10)) + end end +dev = CC.device() + +test_macro_hyhiene(dev)