Skip to content

Commit

Permalink
Support passing symbols as arguments (#2624)
Browse files Browse the repository at this point in the history
Co-authored-by: Tim Besard <[email protected]>
  • Loading branch information
vchuravy and maleadt authored Jan 20, 2025
1 parent 7bee37c commit 4bec614
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 15 deletions.
16 changes: 11 additions & 5 deletions lib/cudadrv/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@

export cudacall

# In contrast to `Base.RefValue` we just need a container for both pass-by-ref (Symbol),
# and pass-by-value (immutable structs).
mutable struct ArgBox{T}
const val::T
end

function Base.unsafe_convert(P::Union{Type{Ptr{T}}, Type{Ptr{Cvoid}}}, b::ArgBox{T})::P where {T}
# TODO: What to do if T is not a leaftype (compare case 3 for RefValue)
return pointer_from_objref(b)
end

## device

# pack arguments in a buffer that CUDA expects
@inline @generated function pack_arguments(f::Function, args...)
for arg in args
isbitstype(arg) || throw(ArgumentError("Arguments to kernel should be bitstype."))
end

ex = quote end

# If f has N parameters, then kernelParams needs to be an array of N pointers.
Expand All @@ -21,7 +27,7 @@ export cudacall
arg_refs = Vector{Symbol}(undef, length(args))
for i in 1:length(args)
arg_refs[i] = gensym()
push!(ex.args, :($(arg_refs[i]) = Base.RefValue(args[$i])))
push!(ex.args, :($(arg_refs[i]) = $ArgBox(args[$i])))
end

# generate an array with pointers
Expand Down
7 changes: 6 additions & 1 deletion src/compiler/compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,11 @@ end
CompilerConfig(target, params; kernel, name, always_inline)
end

# a version of `sizeof` that returns the size of the argument we'll pass.
# for example, it supports Symbols where `sizeof(Symbol)` would fail.
argsize(x::Any) = sizeof(x)
argsize(::Type{Symbol}) = sizeof(Ptr{Cvoid})

# compile to executable machine code
function compile(@nospecialize(job::CompilerJob))
# lower to PTX
Expand Down Expand Up @@ -281,7 +286,7 @@ function compile(@nospecialize(job::CompilerJob))
argtypes = filter([KernelState, job.source.specTypes.parameters...]) do dt
!isghosttype(dt) && !Core.Compiler.isconstType(dt)
end
param_usage = sum(sizeof, argtypes)
param_usage = sum(argsize, argtypes)
param_limit = 4096
if cap >= v"7.0" && ptx >= v"8.1"
param_limit = 32764
Expand Down
9 changes: 0 additions & 9 deletions src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,6 @@ end
call_t = Type[x[1] for x in zip(sig.parameters, to_pass) if x[2]]
call_args = Union{Expr,Symbol}[x[1] for x in zip(argexprs, to_pass) if x[2]]

# replace non-isbits arguments (they should be unused, or compilation would have failed)
# alternatively, make it possible to `launch` with non-isbits arguments.
for (i,dt) in enumerate(call_t)
if !isbitstype(dt)
call_t[i] = Ptr{Any}
call_args[i] = :C_NULL
end
end

# add the kernel state, passing an instance with a unique seed
pushfirst!(call_t, KernelState)
pushfirst!(call_args, :(KernelState(kernel.state.exception_info, make_seed(kernel))))
Expand Down
13 changes: 13 additions & 0 deletions test/core/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,19 @@ end
@test_throws "Kernel invocation uses too much parameter memory" @cuda kernel(ntuple(_->UInt64(1), 2^13))
end

@testset "symbols" begin
function pass_symbol(x, name)
i = name == :var ? 1 : 2
x[i] = true
return nothing
end
x = CuArray([false, false])
@cuda pass_symbol(x, :var)
@test Array(x) == [true, false]
@cuda pass_symbol(x, :not_var)
@test Array(x) == [true, true]
end

end

############################################################################################
Expand Down

0 comments on commit 4bec614

Please sign in to comment.