diff --git a/src/ParallelKernel/hide_communication.jl b/src/ParallelKernel/hide_communication.jl index a2758c2c..2360cc27 100644 --- a/src/ParallelKernel/hide_communication.jl +++ b/src/ParallelKernel/hide_communication.jl @@ -147,17 +147,35 @@ function hide_communication_gpu(ranges_outer::Union{Symbol,Expr}, ranges_inner:: push!(compcalls_outer, :(@parallel_async $ranges_outer[i] stream=ParallelStencil.ParallelKernel.@get_priority_stream(i) $(kwargs...) $compkernelcall)) #NOTE: it cannot directly go to ParallelStencil.ParallelKernel.@parallel_async as else it cannot honour ParallelStencil args as memopt (fixing it to ParallelStencil is also not possible as it assumes, else the ParalellKernel hide_communication unit tests fail). push!(compcalls_inner, :(@parallel_async $ranges_inner[i] stream=ParallelStencil.ParallelKernel.@get_stream(i) $(kwargs...) $compkernelcall)) #NOTE: ... end - bc_and_commcalls = process_bc_and_commcalls(bc_and_commcalls) - quote - for i in 1:length($ranges_outer) - $(compcalls_outer...) + bc_and_commcalls = flatten(process_bc_and_commcalls(bc_and_commcalls)) + if comm_is_splitted(bc_and_commcalls) + bc_and_commcalls_z, bc_and_commcalls_xy = split_bc_and_commcalls(bc_and_commcalls) + quote + for i in 1:length($ranges_outer) + $(compcalls_outer...) + end + for i in 2:3 ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end # NOTE: synchronize the streams of the z-boundary computations (assumed to be stream 2 and 3 - to be in agreement with get_ranges_outer) + $bc_and_commcalls_z + for i in 1:length($ranges_inner) + $(compcalls_inner...) + end + for i in 1:1 ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end + for i in 4:length($ranges_outer) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end + $bc_and_commcalls_xy + for i in 1:length($ranges_inner) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_stream(i)); end end - for i in 1:length($ranges_inner) - $(compcalls_inner...) + else + quote + for i in 1:length($ranges_outer) + $(compcalls_outer...) + end + for i in 1:length($ranges_inner) + $(compcalls_inner...) + end + for i in 1:length($ranges_outer) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end + $bc_and_commcalls + for i in 1:length($ranges_inner) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_stream(i)); end end - for i in 1:length($ranges_outer) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end - $bc_and_commcalls - for i in 1:length($ranges_inner) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_stream(i)); end end end @@ -169,6 +187,10 @@ end function hide_communication_gpu(boundary_width::Union{Integer,Symbol,Expr}, block::Expr; computation_calls::Integer=1) if (computation_calls < 1) @KeywordArgumentError("Invalid keyword argument in @hide_communication: computation_calls must be >= 1.") end compcalls, bc_and_commcalls = extract_calls(block, computation_calls) + + USE_EXPERIMENTAL = false + if (USE_EXPERIMENTAL) bc_and_commcalls = flatten(split_commcalls(bc_and_commcalls)) end + compranges = [] for i in 1:length(compcalls) parallel_args = extract_args(compcalls[i], Symbol("@parallel")) @@ -226,6 +248,41 @@ function process_bc_and_commcalls(block::Expr) end end +function split_commcalls(block::Expr) + return postwalk(block) do x + if !(!@capture(x, f_(args__; kwargs__)) && @capture(x, f_(args__)) && f == :update_halo!) return x; end + return :(update_halo!($(args...); dims=(3,)); update_halo!($(args...); dims=(1,2))) + end +end + +function comm_is_splitted(block::Expr) + if !is_block(block) return false; end + statements = block.args + has_comm_z = false + has_comm_xy = false + for statement in statements + if @capture(statement, f_(args__; kwarg_)) + if !has_comm_z && @capture(kwarg, dims=(3,)) has_comm_z = true + elseif has_comm_z && @capture(kwarg, dims=(1,2)) has_comm_xy = true + end + end + end + return has_comm_z && has_comm_xy +end + +function split_bc_and_commcalls(block::Expr) + if !is_block(block) @ModuleInternalError("expression is not a block; a block with at least two statements for communication is expected (obtained: $block)") end + statements = block.args + comm_z_pos = -1 + for i in length(statements):-1:1 + if (@capture(statements[i], f_(args__; kwarg_)) && @capture(kwarg, dims=(3,))) comm_z_pos = i; break; end + end + if (comm_z_pos < 1) @ModuleInternalError("no communication statement with dims=(3,) found in the block.") end + bc_and_commcalls_z = quote $(statements[1:comm_z_pos]...) end + bc_and_commcalls_xy = quote $(statements[comm_z_pos+1:end]...) end + return bc_and_commcalls_z, bc_and_commcalls_xy +end + ## FUNCTIONS TO GET INNER AND OUTER RANGES AND TO PROMOTE BOUNDARY_WIDTH TO 3D @@ -247,22 +304,25 @@ function get_ranges_outer(boundary_width, ranges::RANGES_TYPE...) ms = length.(ranges) bw = boundary_width if ms[3] > 1 # 3D - ranges_outer = ((1:ms[1], 1:ms[2], 1:bw[3]), - (1:ms[1], 1:ms[2], ms[3]-bw[3]+1:ms[3]), - (1:ms[1], 1:bw[2], bw[3]+1:ms[3]-bw[3]), - (1:ms[1], ms[2]-bw[2]+1:ms[2], bw[3]+1:ms[3]-bw[3]), - (1:bw[1], bw[2]+1:ms[2]-bw[2], bw[3]+1:ms[3]-bw[3]), - (ms[1]-bw[1]+1:ms[1], bw[2]+1:ms[2]-bw[2], bw[3]+1:ms[3]-bw[3]), + ranges_outer = ( + (1:bw[1], bw[2]+1:ms[2]-bw[2], bw[3]+1:ms[3]-bw[3]), # 5 + (1:ms[1], 1:ms[2], 1:bw[3]), # 1 + (1:ms[1], 1:ms[2], ms[3]-bw[3]+1:ms[3]), # 2 + (ms[1]-bw[1]+1:ms[1], bw[2]+1:ms[2]-bw[2], bw[3]+1:ms[3]-bw[3]), # 6 + (1:ms[1], 1:bw[2], bw[3]+1:ms[3]-bw[3]), # 3 + (1:ms[1], ms[2]-bw[2]+1:ms[2], bw[3]+1:ms[3]-bw[3]), # 4 ) elseif ms[2] > 1 # 2D - ranges_outer = ((1:ms[1], 1:bw[2], 1:1), - (1:ms[1], ms[2]-bw[2]+1:ms[2], 1:1), - (1:bw[1], bw[2]+1:ms[2]-bw[2], 1:1), - (ms[1]-bw[1]+1:ms[1], bw[2]+1:ms[2]-bw[2], 1:1), + ranges_outer = ( + (ms[1]-bw[1]+1:ms[1], bw[2]+1:ms[2]-bw[2], 1:1), # 4 + (1:ms[1], 1:bw[2], 1:1), # 1 + (1:ms[1], ms[2]-bw[2]+1:ms[2], 1:1), # 2 + (1:bw[1], bw[2]+1:ms[2]-bw[2], 1:1), # 3 ) elseif ms[1] > 1 # 1D - ranges_outer = ((1:bw[1], 1:1, 1:1), - (ms[1]-bw[1]+1:ms[1], 1:1, 1:1), + ranges_outer = ( + (ms[1]-bw[1]+1:ms[1], 1:1, 1:1), # 2 + (1:bw[1], 1:1, 1:1), # 1 ) else @ModuleInternalError("invalid argument 'ranges'.") diff --git a/src/ParallelKernel/parallel.jl b/src/ParallelKernel/parallel.jl index 20457b06..46c991b3 100644 --- a/src/ParallelKernel/parallel.jl +++ b/src/ParallelKernel/parallel.jl @@ -302,8 +302,8 @@ end ## @SYNCHRONIZE FUNCTIONS -synchronize_cuda(args::Union{Symbol,Expr}...) = :(CUDA.synchronize($(args...))) -synchronize_amdgpu(args::Union{Symbol,Expr}...) = :(AMDGPU.synchronize($(args...))) +synchronize_cuda(args::Union{Symbol,Expr}...) = :(CUDA.synchronize($(args...); blocking=true)) +synchronize_amdgpu(args::Union{Symbol,Expr}...) = :(AMDGPU.synchronize($(args...); blocking=true)) synchronize_threads(args::Union{Symbol,Expr}...) = :(begin end) synchronize_polyester(args::Union{Symbol,Expr}...) = :(begin end) diff --git a/test/ParallelKernel/test_parallel.jl b/test/ParallelKernel/test_parallel.jl index 38302b66..021e69fc 100644 --- a/test/ParallelKernel/test_parallel.jl +++ b/test/ParallelKernel/test_parallel.jl @@ -34,7 +34,7 @@ import Enzyme @static if $package == $PKG_CUDA call = @prettystring(1, @parallel f(A)) @test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32) stream = CUDA.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call) - @test occursin("CUDA.synchronize(CUDA.stream())", call) + @test occursin("CUDA.synchronize(CUDA.stream(); blocking = true)", call) call = @prettystring(1, @parallel ranges f(A)) @test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32) stream = CUDA.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call) call = @prettystring(1, @parallel nblocks nthreads f(A)) @@ -46,7 +46,7 @@ import Enzyme elseif $package == $PKG_AMDGPU call = @prettystring(1, @parallel f(A)) @test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 64)) groupsize = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 64) stream = AMDGPU.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call) - @test occursin("AMDGPU.synchronize(AMDGPU.stream())", call) + @test occursin("AMDGPU.synchronize(AMDGPU.stream(); blocking = true)", call) call = @prettystring(1, @parallel ranges f(A)) @test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 64)) groupsize = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 64) stream = AMDGPU.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call) call = @prettystring(1, @parallel nblocks nthreads f(A)) @@ -401,11 +401,11 @@ import Enzyme end; @testset "@synchronize" begin @static if $package == $PKG_CUDA - @test @prettystring(1, @synchronize()) == "CUDA.synchronize()" - @test @prettystring(1, @synchronize(mystream)) == "CUDA.synchronize(mystream)" + @test @prettystring(1, @synchronize()) == "CUDA.synchronize(; blocking = true)" + @test @prettystring(1, @synchronize(mystream)) == "CUDA.synchronize(mystream; blocking = true)" elseif $package == $PKG_AMDGPU - @test @prettystring(1, @synchronize()) == "AMDGPU.synchronize()" - @test @prettystring(1, @synchronize(mystream)) == "AMDGPU.synchronize(mystream)" + @test @prettystring(1, @synchronize()) == "AMDGPU.synchronize(; blocking = true)" + @test @prettystring(1, @synchronize(mystream)) == "AMDGPU.synchronize(mystream; blocking = true)" end; end; @reset_parallel_kernel() diff --git a/test/test_parallel.jl b/test/test_parallel.jl index def2adce..b4d6e2f7 100644 --- a/test/test_parallel.jl +++ b/test/test_parallel.jl @@ -29,7 +29,7 @@ import ParallelStencil.@gorgeousexpand @static if $package == $PKG_CUDA call = @prettystring(1, @parallel f(A)) @test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32) stream = CUDA.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call) - @test occursin("CUDA.synchronize(CUDA.stream())", call) + @test occursin("CUDA.synchronize(CUDA.stream(); blocking = true)", call) call = @prettystring(1, @parallel ranges f(A)) @test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32) stream = CUDA.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call) call = @prettystring(1, @parallel nblocks nthreads f(A)) @@ -45,7 +45,7 @@ import ParallelStencil.@gorgeousexpand elseif $package == $PKG_AMDGPU call = @prettystring(1, @parallel f(A)) @test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 64)) groupsize = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 64) stream = AMDGPU.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call) - @test occursin("AMDGPU.synchronize(AMDGPU.stream())", call) + @test occursin("AMDGPU.synchronize(AMDGPU.stream(); blocking = true)", call) call = @prettystring(1, @parallel ranges f(A)) @test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 64)) groupsize = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 64) stream = AMDGPU.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call) call = @prettystring(1, @parallel nblocks nthreads f(A))