Skip to content

Commit

Permalink
Forbid divergent execution of work-group barriers
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Feb 12, 2025
1 parent f88ee87 commit a5f740a
Showing 1 changed file with 5 additions and 20 deletions.
25 changes: 5 additions & 20 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,8 @@ function transform_gpu!(def, constargs, force_inbounds)
end

struct WorkgroupLoop
indicies::Vector{Any}
stmts::Vector{Any}
allocations::Vector{Any}
private_allocations::Vector{Any}
private::Set{Symbol}
terminated_in_sync::Bool
end

Expand All @@ -106,26 +103,18 @@ function find_sync(stmt)
end

# TODO proper handling of LineInfo
function split(
stmts,
indicies = Any[], private = Set{Symbol}(),
)
function split(stmts)

Check warning on line 106 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L106

Added line #L106 was not covered by tests
# 1. Split the code into blocks separated by `@synchronize`
# 2. Aggregate `@index` expressions
# 3. Hoist allocations
# 4. Hoist uniforms

current = Any[]
allocations = Any[]
private_allocations = Any[]
new_stmts = Any[]
for stmt in stmts
has_sync = find_sync(stmt)
if has_sync
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private), is_sync(stmt))
loop = WorkgroupLoop(current, allocations, is_sync(stmt))

Check warning on line 115 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L115

Added line #L115 was not covered by tests
push!(new_stmts, emit(loop))
allocations = Any[]
private_allocations = Any[]
current = Any[]
is_sync(stmt) && continue

Expand All @@ -137,7 +126,7 @@ function split(
function recurse(expr::Expr)
expr = unblock(expr)
if is_scope_construct(expr) && any(find_sync, expr.args)
new_args = unblock(split(expr.args, deepcopy(indicies), deepcopy(private)))
new_args = unblock(split(expr.args))

Check warning on line 129 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L129

Added line #L129 was not covered by tests
return Expr(expr.head, new_args...)
else
return Expr(expr.head, map(recurse, expr.args)...)
Expand All @@ -151,8 +140,7 @@ function split(
push!(allocations, stmt)
continue
elseif @capture(stmt, @private lhs_ = rhs_)
push!(private, lhs)
push!(private_allocations, :($lhs = $rhs))
push!(allocations, :($lhs = $rhs))

Check warning on line 143 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L143

Added line #L143 was not covered by tests
continue
elseif @capture(stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
if @capture(rhs, @index(args__))
Expand All @@ -170,7 +158,6 @@ function split(
end
alloc = :($Scratchpad(__ctx__, $T, Val($dims)))
push!(allocations, :($lhs = $alloc))
push!(private, lhs)
continue
end
end
Expand All @@ -180,7 +167,7 @@ function split(

# everything since the last `@synchronize`
if !isempty(current)
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private), false)
loop = WorkgroupLoop(current, allocations, false)

Check warning on line 170 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L170

Added line #L170 was not covered by tests
push!(new_stmts, emit(loop))
end
return new_stmts
Expand All @@ -192,9 +179,7 @@ function emit(loop)
body = Expr(:block, loop.stmts...)
loopexpr = quote
$(loop.allocations...)
$(loop.private_allocations...)
if __active_lane__
$(loop.indicies...)
$(unblock(body))
end
end
Expand Down

0 comments on commit a5f740a

Please sign in to comment.