Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[0.10] Forbid divergent execution of work-group barriers #558

Merged
merged 1 commit into from
Feb 17, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 6 additions & 24 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,8 @@ function transform_gpu!(def, constargs, force_inbounds, unsafe_indices)
end

struct WorkgroupLoop
indices::Vector{Any}
stmts::Vector{Any}
allocations::Vector{Any}
private_allocations::Vector{Any}
private::Set{Symbol}
terminated_in_sync::Bool
end

Expand All @@ -111,26 +108,18 @@ function find_sync(stmt)
end

# TODO proper handling of LineInfo
function split(
stmts,
indices = Any[], private = Set{Symbol}(),
)
function split(stmts)
# 1. Split the code into blocks separated by `@synchronize`
# 2. Aggregate `@index` expressions
# 3. Hoist allocations
# 4. Hoist uniforms

current = Any[]
allocations = Any[]
private_allocations = Any[]
new_stmts = Any[]
for stmt in stmts
has_sync = find_sync(stmt)
if has_sync
loop = WorkgroupLoop(deepcopy(indices), current, allocations, private_allocations, deepcopy(private), is_sync(stmt))
loop = WorkgroupLoop(current, allocations, is_sync(stmt))
push!(new_stmts, emit(loop))
allocations = Any[]
private_allocations = Any[]
current = Any[]
is_sync(stmt) && continue

Expand All @@ -142,7 +131,7 @@ function split(
function recurse(expr::Expr)
expr = unblock(expr)
if is_scope_construct(expr) && any(find_sync, expr.args)
new_args = unblock(split(expr.args, deepcopy(indices), deepcopy(private)))
new_args = unblock(split(expr.args))
return Expr(expr.head, new_args...)
else
return Expr(expr.head, map(recurse, expr.args)...)
Expand All @@ -156,14 +145,10 @@ function split(
push!(allocations, stmt)
continue
elseif @capture(stmt, @private lhs_ = rhs_)
push!(private, lhs)
push!(private_allocations, :($lhs = $rhs))
push!(allocations, :($lhs = $rhs))
continue
elseif @capture(stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
if @capture(rhs, @index(args__))
push!(indices, stmt)
continue
elseif @capture(rhs, @localmem(args__) | @uniform(args__))
if @capture(rhs, @localmem(args__) | @uniform(args__))
push!(allocations, stmt)
continue
elseif @capture(rhs, @private(T_, dims_))
Expand All @@ -175,7 +160,6 @@ function split(
end
alloc = :($Scratchpad(__ctx__, $T, Val($dims)))
push!(allocations, :($lhs = $alloc))
push!(private, lhs)
continue
end
end
Expand All @@ -185,7 +169,7 @@ function split(

# everything since the last `@synchronize`
if !isempty(current)
loop = WorkgroupLoop(deepcopy(indices), current, allocations, private_allocations, deepcopy(private), false)
loop = WorkgroupLoop(current, allocations, false)
push!(new_stmts, emit(loop))
end
return new_stmts
Expand All @@ -197,9 +181,7 @@ function emit(loop)
body = Expr(:block, loop.stmts...)
loopexpr = quote
$(loop.allocations...)
$(loop.private_allocations...)
if __active_lane__
$(loop.indices...)
$(unblock(body))
end
end
Expand Down
Loading