Skip to content

Commit

Permalink
Merge pull request #161 from omlins/hide
Browse files Browse the repository at this point in the history
Do minor hide communication improvement
  • Loading branch information
omlins authored Aug 12, 2024
2 parents 26eacdc + 1b22a7f commit e3f7e13
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 31 deletions.
102 changes: 81 additions & 21 deletions src/ParallelKernel/hide_communication.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,35 @@ function hide_communication_gpu(ranges_outer::Union{Symbol,Expr}, ranges_inner::
push!(compcalls_outer, :(@parallel_async $ranges_outer[i] stream=ParallelStencil.ParallelKernel.@get_priority_stream(i) $(kwargs...) $compkernelcall)) #NOTE: it cannot directly go to ParallelStencil.ParallelKernel.@parallel_async as else it cannot honour ParallelStencil args as memopt (fixing it to ParallelStencil is also not possible as it assumes, else the ParalellKernel hide_communication unit tests fail).
push!(compcalls_inner, :(@parallel_async $ranges_inner[i] stream=ParallelStencil.ParallelKernel.@get_stream(i) $(kwargs...) $compkernelcall)) #NOTE: ...
end
bc_and_commcalls = process_bc_and_commcalls(bc_and_commcalls)
quote
for i in 1:length($ranges_outer)
$(compcalls_outer...)
bc_and_commcalls = flatten(process_bc_and_commcalls(bc_and_commcalls))
if comm_is_splitted(bc_and_commcalls)
bc_and_commcalls_z, bc_and_commcalls_xy = split_bc_and_commcalls(bc_and_commcalls)
quote
for i in 1:length($ranges_outer)
$(compcalls_outer...)
end
for i in 2:3 ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end # NOTE: synchronize the streams of the z-boundary computations (assumed to be stream 2 and 3 - to be in agreement with get_ranges_outer)
$bc_and_commcalls_z
for i in 1:length($ranges_inner)
$(compcalls_inner...)
end
for i in 1:1 ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end
for i in 4:length($ranges_outer) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end
$bc_and_commcalls_xy
for i in 1:length($ranges_inner) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_stream(i)); end
end
for i in 1:length($ranges_inner)
$(compcalls_inner...)
else
quote
for i in 1:length($ranges_outer)
$(compcalls_outer...)
end
for i in 1:length($ranges_inner)
$(compcalls_inner...)
end
for i in 1:length($ranges_outer) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end
$bc_and_commcalls
for i in 1:length($ranges_inner) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_stream(i)); end
end
for i in 1:length($ranges_outer) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end
$bc_and_commcalls
for i in 1:length($ranges_inner) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_stream(i)); end
end
end

Expand All @@ -169,6 +187,10 @@ end
function hide_communication_gpu(boundary_width::Union{Integer,Symbol,Expr}, block::Expr; computation_calls::Integer=1)
if (computation_calls < 1) @KeywordArgumentError("Invalid keyword argument in @hide_communication: computation_calls must be >= 1.") end
compcalls, bc_and_commcalls = extract_calls(block, computation_calls)

USE_EXPERIMENTAL = false
if (USE_EXPERIMENTAL) bc_and_commcalls = flatten(split_commcalls(bc_and_commcalls)) end

compranges = []
for i in 1:length(compcalls)
parallel_args = extract_args(compcalls[i], Symbol("@parallel"))
Expand Down Expand Up @@ -226,6 +248,41 @@ function process_bc_and_commcalls(block::Expr)
end
end

function split_commcalls(block::Expr)
return postwalk(block) do x
if !(!@capture(x, f_(args__; kwargs__)) && @capture(x, f_(args__)) && f == :update_halo!) return x; end
return :(update_halo!($(args...); dims=(3,)); update_halo!($(args...); dims=(1,2)))
end
end

function comm_is_splitted(block::Expr)
if !is_block(block) return false; end
statements = block.args
has_comm_z = false
has_comm_xy = false
for statement in statements
if @capture(statement, f_(args__; kwarg_))
if !has_comm_z && @capture(kwarg, dims=(3,)) has_comm_z = true
elseif has_comm_z && @capture(kwarg, dims=(1,2)) has_comm_xy = true
end
end
end
return has_comm_z && has_comm_xy
end

function split_bc_and_commcalls(block::Expr)
if !is_block(block) @ModuleInternalError("expression is not a block; a block with at least two statements for communication is expected (obtained: $block)") end
statements = block.args
comm_z_pos = -1
for i in length(statements):-1:1
if (@capture(statements[i], f_(args__; kwarg_)) && @capture(kwarg, dims=(3,))) comm_z_pos = i; break; end
end
if (comm_z_pos < 1) @ModuleInternalError("no communication statement with dims=(3,) found in the block.") end
bc_and_commcalls_z = quote $(statements[1:comm_z_pos]...) end
bc_and_commcalls_xy = quote $(statements[comm_z_pos+1:end]...) end
return bc_and_commcalls_z, bc_and_commcalls_xy
end


## FUNCTIONS TO GET INNER AND OUTER RANGES AND TO PROMOTE BOUNDARY_WIDTH TO 3D

Expand All @@ -247,22 +304,25 @@ function get_ranges_outer(boundary_width, ranges::RANGES_TYPE...)
ms = length.(ranges)
bw = boundary_width
if ms[3] > 1 # 3D
ranges_outer = ((1:ms[1], 1:ms[2], 1:bw[3]),
(1:ms[1], 1:ms[2], ms[3]-bw[3]+1:ms[3]),
(1:ms[1], 1:bw[2], bw[3]+1:ms[3]-bw[3]),
(1:ms[1], ms[2]-bw[2]+1:ms[2], bw[3]+1:ms[3]-bw[3]),
(1:bw[1], bw[2]+1:ms[2]-bw[2], bw[3]+1:ms[3]-bw[3]),
(ms[1]-bw[1]+1:ms[1], bw[2]+1:ms[2]-bw[2], bw[3]+1:ms[3]-bw[3]),
ranges_outer = (
(1:bw[1], bw[2]+1:ms[2]-bw[2], bw[3]+1:ms[3]-bw[3]), # 5
(1:ms[1], 1:ms[2], 1:bw[3]), # 1
(1:ms[1], 1:ms[2], ms[3]-bw[3]+1:ms[3]), # 2
(ms[1]-bw[1]+1:ms[1], bw[2]+1:ms[2]-bw[2], bw[3]+1:ms[3]-bw[3]), # 6
(1:ms[1], 1:bw[2], bw[3]+1:ms[3]-bw[3]), # 3
(1:ms[1], ms[2]-bw[2]+1:ms[2], bw[3]+1:ms[3]-bw[3]), # 4
)
elseif ms[2] > 1 # 2D
ranges_outer = ((1:ms[1], 1:bw[2], 1:1),
(1:ms[1], ms[2]-bw[2]+1:ms[2], 1:1),
(1:bw[1], bw[2]+1:ms[2]-bw[2], 1:1),
(ms[1]-bw[1]+1:ms[1], bw[2]+1:ms[2]-bw[2], 1:1),
ranges_outer = (
(ms[1]-bw[1]+1:ms[1], bw[2]+1:ms[2]-bw[2], 1:1), # 4
(1:ms[1], 1:bw[2], 1:1), # 1
(1:ms[1], ms[2]-bw[2]+1:ms[2], 1:1), # 2
(1:bw[1], bw[2]+1:ms[2]-bw[2], 1:1), # 3
)
elseif ms[1] > 1 # 1D
ranges_outer = ((1:bw[1], 1:1, 1:1),
(ms[1]-bw[1]+1:ms[1], 1:1, 1:1),
ranges_outer = (
(ms[1]-bw[1]+1:ms[1], 1:1, 1:1), # 2
(1:bw[1], 1:1, 1:1), # 1
)
else
@ModuleInternalError("invalid argument 'ranges'.")
Expand Down
4 changes: 2 additions & 2 deletions src/ParallelKernel/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ end

## @SYNCHRONIZE FUNCTIONS

synchronize_cuda(args::Union{Symbol,Expr}...) = :(CUDA.synchronize($(args...)))
synchronize_amdgpu(args::Union{Symbol,Expr}...) = :(AMDGPU.synchronize($(args...)))
synchronize_cuda(args::Union{Symbol,Expr}...) = :(CUDA.synchronize($(args...); blocking=true))
synchronize_amdgpu(args::Union{Symbol,Expr}...) = :(AMDGPU.synchronize($(args...); blocking=true))
synchronize_threads(args::Union{Symbol,Expr}...) = :(begin end)
synchronize_polyester(args::Union{Symbol,Expr}...) = :(begin end)

Expand Down
12 changes: 6 additions & 6 deletions test/ParallelKernel/test_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import Enzyme
@static if $package == $PKG_CUDA
call = @prettystring(1, @parallel f(A))
@test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32) stream = CUDA.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call)
@test occursin("CUDA.synchronize(CUDA.stream())", call)
@test occursin("CUDA.synchronize(CUDA.stream(); blocking = true)", call)
call = @prettystring(1, @parallel ranges f(A))
@test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32) stream = CUDA.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call)
call = @prettystring(1, @parallel nblocks nthreads f(A))
Expand All @@ -46,7 +46,7 @@ import Enzyme
elseif $package == $PKG_AMDGPU
call = @prettystring(1, @parallel f(A))
@test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 64)) groupsize = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 64) stream = AMDGPU.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call)
@test occursin("AMDGPU.synchronize(AMDGPU.stream())", call)
@test occursin("AMDGPU.synchronize(AMDGPU.stream(); blocking = true)", call)
call = @prettystring(1, @parallel ranges f(A))
@test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 64)) groupsize = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 64) stream = AMDGPU.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call)
call = @prettystring(1, @parallel nblocks nthreads f(A))
Expand Down Expand Up @@ -401,11 +401,11 @@ import Enzyme
end;
@testset "@synchronize" begin
@static if $package == $PKG_CUDA
@test @prettystring(1, @synchronize()) == "CUDA.synchronize()"
@test @prettystring(1, @synchronize(mystream)) == "CUDA.synchronize(mystream)"
@test @prettystring(1, @synchronize()) == "CUDA.synchronize(; blocking = true)"
@test @prettystring(1, @synchronize(mystream)) == "CUDA.synchronize(mystream; blocking = true)"
elseif $package == $PKG_AMDGPU
@test @prettystring(1, @synchronize()) == "AMDGPU.synchronize()"
@test @prettystring(1, @synchronize(mystream)) == "AMDGPU.synchronize(mystream)"
@test @prettystring(1, @synchronize()) == "AMDGPU.synchronize(; blocking = true)"
@test @prettystring(1, @synchronize(mystream)) == "AMDGPU.synchronize(mystream; blocking = true)"
end;
end;
@reset_parallel_kernel()
Expand Down
4 changes: 2 additions & 2 deletions test/test_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import ParallelStencil.@gorgeousexpand
@static if $package == $PKG_CUDA
call = @prettystring(1, @parallel f(A))
@test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32) stream = CUDA.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call)
@test occursin("CUDA.synchronize(CUDA.stream())", call)
@test occursin("CUDA.synchronize(CUDA.stream(); blocking = true)", call)
call = @prettystring(1, @parallel ranges f(A))
@test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32) stream = CUDA.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call)
call = @prettystring(1, @parallel nblocks nthreads f(A))
Expand All @@ -45,7 +45,7 @@ import ParallelStencil.@gorgeousexpand
elseif $package == $PKG_AMDGPU
call = @prettystring(1, @parallel f(A))
@test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 64)) groupsize = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 64) stream = AMDGPU.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call)
@test occursin("AMDGPU.synchronize(AMDGPU.stream())", call)
@test occursin("AMDGPU.synchronize(AMDGPU.stream(); blocking = true)", call)
call = @prettystring(1, @parallel ranges f(A))
@test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 64)) groupsize = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 64) stream = AMDGPU.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call)
call = @prettystring(1, @parallel nblocks nthreads f(A))
Expand Down

0 comments on commit e3f7e13

Please sign in to comment.