Skip to content

Commit

Permalink
Forbid divergent execution of work-group barriers
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Feb 14, 2025
1 parent ed2ee63 commit b58c830
Showing 1 changed file with 6 additions and 24 deletions.
30 changes: 6 additions & 24 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,8 @@ function transform_gpu!(def, constargs, force_inbounds, unsafe_indices)
end

# Record describing one chunk of kernel statements. `split` constructs it and
# hands it straight to `emit`, which turns it into an expression.
#
# NOTE(review): this listing appears to come from a diff view that shows both
# sides of a change. The three-argument constructor calls in `split`
# (`WorkgroupLoop(current, allocations, ...)`) suggest that only `stmts`,
# `allocations` and `terminated_in_sync` survive -- confirm against the real file.
struct WorkgroupLoop
# Statements whose right-hand side is an `@index(...)` call; `emit` splices
# them inside the `if __active_lane__` guard, ahead of the body.
indices::Vector{Any}
# Statements gathered since the previous chunk was emitted; `emit` wraps them
# in a `:block` expression that becomes the guarded body.
stmts::Vector{Any}
# Statements that `split` hoists out of the body; `emit` splices them before
# the `if __active_lane__` guard.
allocations::Vector{Any}
# `lhs = rhs` assignments derived from `@private lhs = rhs`; also spliced
# before the guard by `emit`.
private_allocations::Vector{Any}
# Left-hand-side names collected from `@private` statements.
private::Set{Symbol}
# `is_sync(stmt)` for a chunk closed by a sync-containing statement; `false`
# for the trailing chunk that follows the last one.
terminated_in_sync::Bool
end

Expand All @@ -111,26 +108,18 @@ function find_sync(stmt)
end

# TODO proper handling of LineInfo
function split(
stmts,
indices = Any[], private = Set{Symbol}(),
)
function split(stmts)

Check warning on line 111 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L111

Added line #L111 was not covered by tests
# 1. Split the code into blocks separated by `@synchronize`
# 2. Aggregate `@index` expressions
# 3. Hoist allocations
# 4. Hoist uniforms

current = Any[]
allocations = Any[]
private_allocations = Any[]
new_stmts = Any[]
for stmt in stmts
has_sync = find_sync(stmt)
if has_sync
loop = WorkgroupLoop(deepcopy(indices), current, allocations, private_allocations, deepcopy(private), is_sync(stmt))
loop = WorkgroupLoop(current, allocations, is_sync(stmt))

Check warning on line 120 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L120

Added line #L120 was not covered by tests
push!(new_stmts, emit(loop))
allocations = Any[]
private_allocations = Any[]
current = Any[]
is_sync(stmt) && continue

Expand All @@ -142,7 +131,7 @@ function split(
function recurse(expr::Expr)
expr = unblock(expr)
if is_scope_construct(expr) && any(find_sync, expr.args)
new_args = unblock(split(expr.args, deepcopy(indices), deepcopy(private)))
new_args = unblock(split(expr.args))

Check warning on line 134 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L134

Added line #L134 was not covered by tests
return Expr(expr.head, new_args...)
else
return Expr(expr.head, map(recurse, expr.args)...)
Expand All @@ -156,14 +145,10 @@ function split(
push!(allocations, stmt)
continue
elseif @capture(stmt, @private lhs_ = rhs_)
push!(private, lhs)
push!(private_allocations, :($lhs = $rhs))
push!(allocations, :($lhs = $rhs))

Check warning on line 148 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L148

Added line #L148 was not covered by tests
continue
elseif @capture(stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
if @capture(rhs, @index(args__))
push!(indices, stmt)
continue
elseif @capture(rhs, @localmem(args__) | @uniform(args__))
if @capture(rhs, @localmem(args__) | @uniform(args__))

Check warning on line 151 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L151

Added line #L151 was not covered by tests
push!(allocations, stmt)
continue
elseif @capture(rhs, @private(T_, dims_))
Expand All @@ -175,7 +160,6 @@ function split(
end
alloc = :($Scratchpad(__ctx__, $T, Val($dims)))
push!(allocations, :($lhs = $alloc))
push!(private, lhs)
continue
end
end
Expand All @@ -185,7 +169,7 @@ function split(

# everything since the last `@synchronize`
if !isempty(current)
loop = WorkgroupLoop(deepcopy(indices), current, allocations, private_allocations, deepcopy(private), false)
loop = WorkgroupLoop(current, allocations, false)

Check warning on line 172 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L172

Added line #L172 was not covered by tests
push!(new_stmts, emit(loop))
end
return new_stmts
Expand All @@ -197,9 +181,7 @@ function emit(loop)
body = Expr(:block, loop.stmts...)
loopexpr = quote
$(loop.allocations...)
$(loop.private_allocations...)
if __active_lane__
$(loop.indices...)
$(unblock(body))
end
end
Expand Down

0 comments on commit b58c830

Please sign in to comment.