Skip to content

Commit e991612

Browse files
authored
Add constructor for GPUArrays (#166)
* Add constructor for GPUArrays * Use recursive_eltype * Add constructor for GPUComponentArrays * Dont't convert eltype on GPU * Use SciML definition of parameterless_type
1 parent a7d6f7d commit e991612

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/compat/gpuarrays.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,10 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
6464
GPUArrays.mapreducedim!(f, $(op), getdata(r), getdata(A); init=neutral_element($(op), T))
6565
end
6666
end
67+
68+
function ComponentArray(nt::NamedTuple{names,<:Tuple{Vararg{Union{GPUArrays.AbstractGPUArray,GPUComponentArray}}}}) where {names}
69+
T = recursive_eltype(nt)
70+
gpuarray = getdata(first(nt))
71+
G = Base.typename(typeof(gpuarray)).wrapper # SciMLBase.parameterless_type(gpuarray)
72+
return GPUArrays.adapt(G, ComponentArray(NamedTuple{names}(map(GPUArrays.adapt(Array{T}), nt))))
73+
end

0 commit comments

Comments
 (0)