diff --git a/src/compat/gpuarrays.jl b/src/compat/gpuarrays.jl index 058bb99e..467adfb3 100644 --- a/src/compat/gpuarrays.jl +++ b/src/compat/gpuarrays.jl @@ -64,3 +64,10 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)), GPUArrays.mapreducedim!(f, $(op), getdata(r), getdata(A); init=neutral_element($(op), T)) end end + +function ComponentArray(nt::NamedTuple{names,<:Tuple{Vararg{Union{GPUArrays.AbstractGPUArray,GPUComponentArray}}}}) where {names} + T = recursive_eltype(nt) + gpuarray = getdata(first(nt)) + G = Base.typename(typeof(gpuarray)).wrapper # SciMLBase.parameterless_type(gpuarray) + return GPUArrays.adapt(G, ComponentArray(NamedTuple{names}(map(GPUArrays.adapt(Array{T}), nt)))) +end