From 1163b3252b24af21feb7881957db59f5675bf6ca Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 28 Jan 2025 15:48:57 +0100 Subject: [PATCH] Forbid divergent execution of work-group barriers --- src/macros.jl | 30 ++++++------------------------ 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/src/macros.jl b/src/macros.jl index 2f184dd7..58f4cac3 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -86,11 +86,8 @@ function transform_gpu!(def, constargs, force_inbounds, unsafe_indices) end struct WorkgroupLoop - indices::Vector{Any} stmts::Vector{Any} allocations::Vector{Any} - private_allocations::Vector{Any} - private::Set{Symbol} terminated_in_sync::Bool end @@ -111,26 +108,18 @@ function find_sync(stmt) end # TODO proper handling of LineInfo -function split( - stmts, - indices = Any[], private = Set{Symbol}(), - ) +function split(stmts) # 1. Split the code into blocks separated by `@synchronize` - # 2. Aggregate `@index` expressions - # 3. Hoist allocations - # 4. Hoist uniforms current = Any[] allocations = Any[] - private_allocations = Any[] new_stmts = Any[] for stmt in stmts has_sync = find_sync(stmt) if has_sync - loop = WorkgroupLoop(deepcopy(indices), current, allocations, private_allocations, deepcopy(private), is_sync(stmt)) + loop = WorkgroupLoop(current, allocations, is_sync(stmt)) push!(new_stmts, emit(loop)) allocations = Any[] - private_allocations = Any[] current = Any[] is_sync(stmt) && continue @@ -142,7 +131,7 @@ function split( function recurse(expr::Expr) expr = unblock(expr) if is_scope_construct(expr) && any(find_sync, expr.args) - new_args = unblock(split(expr.args, deepcopy(indices), deepcopy(private))) + new_args = unblock(split(expr.args)) return Expr(expr.head, new_args...) else return Expr(expr.head, map(recurse, expr.args)...) @@ -156,14 +145,10 @@ function split( push!(allocations, stmt) continue elseif @capture(stmt, @private lhs_ = rhs_) - push!(private, lhs) - push!(private_allocations, :($lhs = $rhs)) + push!(allocations, :($lhs = $rhs)) continue elseif @capture(stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_)) - if @capture(rhs, @index(args__)) - push!(indices, stmt) - continue - elseif @capture(rhs, @localmem(args__) | @uniform(args__)) + if @capture(rhs, @localmem(args__) | @uniform(args__)) push!(allocations, stmt) continue elseif @capture(rhs, @private(T_, dims_)) @@ -175,7 +160,6 @@ function split( end alloc = :($Scratchpad(__ctx__, $T, Val($dims))) push!(allocations, :($lhs = $alloc)) - push!(private, lhs) continue end end @@ -185,7 +169,7 @@ function split( # everything since the last `@synchronize` if !isempty(current) - loop = WorkgroupLoop(deepcopy(indices), current, allocations, private_allocations, deepcopy(private), false) + loop = WorkgroupLoop(current, allocations, false) push!(new_stmts, emit(loop)) end return new_stmts @@ -197,9 +181,7 @@ function emit(loop) body = Expr(:block, loop.stmts...) loopexpr = quote $(loop.allocations...) - $(loop.private_allocations...) if __active_lane__ - $(loop.indices...) $(unblock(body)) end end