Skip to content

Commit

Permalink
Forbid divergent execution of work-group barriers
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Feb 7, 2025
1 parent a6ae55b commit a48a158
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 11 deletions.
10 changes: 9 additions & 1 deletion src/KernelAbstractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ end
After a `@synchronize` statement all reads and writes to global and local memory
from each thread in the workgroup are visible to all other threads in the
workgroup.
!!! note
`@synchronize()` must be encountered by all workitems of a work-group executing the kernel or by none at all.
"""
macro synchronize()
return quote
Expand All @@ -314,10 +317,15 @@ workgroup. `cond` is not allowed to have any visible side effects.
# Platform differences
- `GPU`: This synchronization will only occur if `cond` evaluates to `true`.
- `CPU`: This synchronization will always occur.
!!! warning
This variant of the `@synchronize` macro violates the requirement that `@synchronize` must be encountered
by all workitems of a work-group executing the kernel or by none at all.
Since v`0.9.34` this version of the macro is deprecated and lowers to `@synchronize()`
"""
macro synchronize(cond)
    # `cond` is intentionally ignored. A barrier guarded by a per-workitem condition
    # can be reached by only some workitems of a work-group, which violates the
    # requirement that a barrier is encountered by all workitems or by none.
    # The deprecated conditional form therefore emits a single unconditional
    # barrier, matching its documented lowering to `@synchronize()`.
    return quote
        $__synchronize()
    end
end

Expand Down
163 changes: 153 additions & 10 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,165 @@ function transform_gpu!(def, constargs, force_inbounds)
end
end
pushfirst!(def[:args], :__ctx__)
body = def[:body]
new_stmts = Expr[]
body = MacroTools.flatten(def[:body])
stmts = body.args
push!(new_stmts, Expr(:aliasscope))
push!(new_stmts, :(__active_lane__ = $__validindex(__ctx__)))

Check warning on line 65 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L61-L65

Added lines #L61 - L65 were not covered by tests
if force_inbounds
body = quote
@inbounds $(body)
end
push!(new_stmts, Expr(:inbounds, true))

Check warning on line 67 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L67

Added line #L67 was not covered by tests
end
body = quote
if $__validindex(__ctx__)
$(body)
end
return nothing
append!(new_stmts, split(emit_gpu, body.args))
if force_inbounds
push!(new_stmts, Expr(:inbounds, :pop))

Check warning on line 71 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L69-L71

Added lines #L69 - L71 were not covered by tests
end
push!(new_stmts, Expr(:popaliasscope))
push!(new_stmts, :(return nothing))

Check warning on line 74 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L73-L74

Added lines #L73 - L74 were not covered by tests
def[:body] = Expr(
:let,
Expr(:block, let_constargs...),
body,
Expr(:block, new_stmts...),
)
return
end

"""
    WorkgroupLoop

One barrier-delimited region of a kernel body, as collected by `split` and
consumed by an emitter such as `emit_gpu`.

# Fields
- `indicies`: `@index` assignments gathered so far; emitted ahead of the region body.
- `stmts`: the ordinary statements that make up the region body.
- `allocations`: hoisted `@uniform` / `@localmem` / legacy `@private T dims` statements.
- `private_allocations`: `lhs = rhs` assignments originating from `@private lhs = rhs`.
- `private`: names of the variables that were declared private.
"""
struct WorkgroupLoop
    indicies::Vector{Any}
    stmts::Vector{Any}
    allocations::Vector{Any}
    private_allocations::Vector{Any}
    private::Set{Symbol}
end

"""
    is_sync(expr)

Return whether `expr` is a `@synchronize` macro call, either without an
argument or with exactly one argument.
"""
function is_sync(expr)
    return @capture(expr, @synchronize() | @synchronize(a_))
end

Check warning on line 91 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L91

Added line #L91 was not covered by tests

"""
    is_scope_construct(expr::Expr)

Return `true` when `expr` is a construct whose arguments `split` may descend
into while looking for `@synchronize`. Only `:block` expressions qualify.
"""
function is_scope_construct(expr::Expr)
    # NOTE: `:let` expressions are deliberately not treated as scope constructs (yet).
    return expr.head === :block
end

"""
    find_sync(stmt)

Return `true` if `stmt` is a `@synchronize` call (as recognised by `is_sync`)
or contains one anywhere inside it.
"""
function find_sync(stmt)
    result = false
    # Walk every sub-expression; the walk is used only for inspection, so each
    # node is returned unchanged.
    postwalk(stmt) do expr
        result |= is_sync(expr)
        expr
    end
    return result
end

# TODO proper handling of LineInfo
"""
    split(emit, stmts, indicies = Any[], private = Set{Symbol}())

Partition the statements of a kernel body into regions separated by
`@synchronize` and hand each region to `emit`.

# Arguments
- `emit`: callback that receives a `WorkgroupLoop` and returns the expression to
  insert for that region (for example `emit_gpu`).
- `stmts`: the statements to process.
- `indicies`: `@index` assignments collected so far. Newly found ones are pushed
  onto this collection.
- `private`: names of variables declared private so far. Newly found ones are
  pushed onto this set.

# Returns
A vector holding the emitted regions and the rewritten statements that
contained a nested `@synchronize`, in source order.
"""
function split(
        emit,
        stmts,
        indicies = Any[], private = Set{Symbol}(),
    )
    # 1. Split the code into blocks separated by `@synchronize`
    # 2. Aggregate `@index` expressions
    # 3. Hoist allocations
    # 4. Hoist uniforms

    current = Any[]
    allocations = Any[]
    private_allocations = Any[]
    new_stmts = Any[]
    for stmt in stmts
        has_sync = find_sync(stmt)
        if has_sync
            # Close the region collected so far. Each region gets a snapshot of
            # the indices and private names known at this point.
            loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private))
            push!(new_stmts, emit(loop))
            allocations = Any[]
            private_allocations = Any[]
            current = Any[]
            # A bare `@synchronize` only marks the boundary; nothing more to emit.
            is_sync(stmt) && continue

            # Recurse into scope constructs
            # TODO: This currently implements hard scoping
            #       probably need to implement soft scoping
            #       by not deepcopying the environment.
            recurse(x) = x
            function recurse(expr::Expr)
                expr = unblock(expr)
                if is_scope_construct(expr) && any(find_sync, expr.args)
                    new_args = unblock(split(emit, expr.args, deepcopy(indicies), deepcopy(private)))
                    return Expr(expr.head, new_args...)
                else
                    return Expr(expr.head, map(recurse, expr.args)...)
                end
            end
            push!(new_stmts, recurse(stmt))
            continue
        end

        if @capture(stmt, @uniform x_)
            push!(allocations, stmt)
            continue
        elseif @capture(stmt, @private lhs_ = rhs_)
            push!(private, lhs)
            push!(private_allocations, :($lhs = $rhs))
            continue
        elseif @capture(stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
            if @capture(rhs, @index(args__))
                push!(indicies, stmt)
                continue
            elseif @capture(rhs, @localmem(args__) | @uniform(args__))
                push!(allocations, stmt)
                continue
            elseif @capture(rhs, @private(T_, dims_))
                # Implement the legacy `mem = @private T dims` as
                # mem = Scratchpad(__ctx__, T, Val(dims))

                if dims isa Integer
                    dims = (dims,)
                end
                alloc = :($Scratchpad(__ctx__, $T, Val($dims)))
                push!(allocations, :($lhs = $alloc))
                push!(private, lhs)
                continue
            end
        end

        push!(current, stmt)
    end

    # everything since the last `@synchronize`
    if !isempty(current)
        loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private))
        push!(new_stmts, emit(loop))
    end
    return new_stmts
end

"""
    emit_gpu(loop)

Build the GPU expression for one `WorkgroupLoop`.

The result contains, in order: the hoisted allocations, the private
allocations, and (unless the region has no real statements) the region body
guarded by `__active_lane__` and followed by a call to `__synchronize()`.
The barrier call is placed outside the guard.

Throws an error if a private allocation is not an assignment, or if the region
body assigns to a variable that was marked private.
"""
function emit_gpu(loop)
    stmts = Any[]
    append!(stmts, loop.allocations)
    for stmt in loop.private_allocations
        if @capture(stmt, lhs_ = rhs_)
            push!(stmts, :($lhs = $rhs))
        else
            error("@private $stmt not an assignment")
        end
    end

    # don't emit empty loops (`all` is vacuously true for an empty statement list)
    if !all(s -> s isa LineNumberNode, loop.stmts)
        body = Expr(:block, loop.stmts...)
        # Validation-only walk: reject assignments to private variables and
        # return every node unchanged.
        body = postwalk(body) do expr
            if @capture(expr, lhs_ = rhs_)
                if lhs in loop.private
                    error("Can't assign to variables marked private")
                end
            end
            return expr
        end
        loopexpr = quote
            if __active_lane__
                $(loop.indicies...)
                $(unblock(body))
            end
            $__synchronize()
        end
        push!(stmts, loopexpr)
    end

    return unblock(Expr(:block, stmts...))
end

0 comments on commit a48a158

Please sign in to comment.