diff --git a/lib/cudadrv/execution.jl b/lib/cudadrv/execution.jl
index fb4b67bb39..4725e5679b 100644
--- a/lib/cudadrv/execution.jl
+++ b/lib/cudadrv/execution.jl
@@ -2,15 +2,21 @@
 
 export cudacall
 
+# In contrast to `Base.RefValue`, we only need a container that works for both
+# pass-by-ref (Symbol) and pass-by-value (immutable structs) arguments.
+mutable struct ArgBox{T}
+    const val::T
+end
+
+function Base.unsafe_convert(P::Union{Type{Ptr{T}}, Type{Ptr{Cvoid}}}, b::ArgBox{T})::P where {T}
+    # TODO: What to do if T is not a leaftype (compare case 3 for RefValue)
+    return pointer_from_objref(b)
+end
 
 ## device
 
 # pack arguments in a buffer that CUDA expects
 @inline @generated function pack_arguments(f::Function, args...)
-    for arg in args
-        isbitstype(arg) || throw(ArgumentError("Arguments to kernel should be bitstype."))
-    end
-
     ex = quote end
 
     # If f has N parameters, then kernelParams needs to be an array of N pointers.
@@ -21,7 +27,7 @@ export cudacall
     arg_refs = Vector{Symbol}(undef, length(args))
     for i in 1:length(args)
         arg_refs[i] = gensym()
-        push!(ex.args, :($(arg_refs[i]) = Base.RefValue(args[$i])))
+        push!(ex.args, :($(arg_refs[i]) = $ArgBox(args[$i])))
     end
 
     # generate an array with pointers
diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl
index c105bc0b77..8f664ad65e 100644
--- a/src/compiler/compilation.jl
+++ b/src/compiler/compilation.jl
@@ -242,6 +242,11 @@ end
     CompilerConfig(target, params; kernel, name, always_inline)
 end
 
+# a version of `sizeof` that returns the size of the argument as it will be passed.
+# for example, a `Symbol` counts as one pointer, whereas `sizeof(Symbol)` would fail.
+argsize(x::Any) = sizeof(x)
+argsize(::Type{Symbol}) = sizeof(Ptr{Cvoid})
+
 # compile to executable machine code
 function compile(@nospecialize(job::CompilerJob))
     # lower to PTX
@@ -281,7 +286,7 @@ function compile(@nospecialize(job::CompilerJob))
     argtypes = filter([KernelState, job.source.specTypes.parameters...]) do dt
         !isghosttype(dt) && !Core.Compiler.isconstType(dt)
     end
-    param_usage = sum(sizeof, argtypes)
+    param_usage = sum(argsize, argtypes)
     param_limit = 4096
     if cap >= v"7.0" && ptx >= v"8.1"
         param_limit = 32764
diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl
index c77588ad59..82a940db60 100644
--- a/src/compiler/execution.jl
+++ b/src/compiler/execution.jl
@@ -259,15 +259,6 @@ end
     call_t = Type[x[1] for x in zip(sig.parameters, to_pass) if x[2]]
     call_args = Union{Expr,Symbol}[x[1] for x in zip(argexprs, to_pass) if x[2]]
 
-    # replace non-isbits arguments (they should be unused, or compilation would have failed)
-    # alternatively, make it possible to `launch` with non-isbits arguments.
-    for (i,dt) in enumerate(call_t)
-        if !isbitstype(dt)
-            call_t[i] = Ptr{Any}
-            call_args[i] = :C_NULL
-        end
-    end
-
     # add the kernel state, passing an instance with a unique seed
     pushfirst!(call_t, KernelState)
     pushfirst!(call_args, :(KernelState(kernel.state.exception_info, make_seed(kernel))))
diff --git a/test/core/execution.jl b/test/core/execution.jl
index f731e891de..94d92af213 100644
--- a/test/core/execution.jl
+++ b/test/core/execution.jl
@@ -626,6 +626,19 @@ end
         @test_throws "Kernel invocation uses too much parameter memory" @cuda kernel(ntuple(_->UInt64(1), 2^13))
     end
 
+    @testset "symbols" begin
+        function pass_symbol(x, name)
+            i = name == :var ? 1 : 2
+            x[i] = true
+            return nothing
+        end
+        x = CuArray([false, false])
+        @cuda pass_symbol(x, :var)
+        @test Array(x) == [true, false]
+        @cuda pass_symbol(x, :not_var)
+        @test Array(x) == [true, true]
+    end
+
 end
 
 ############################################################################################