Skip to content

Commit

Permalink
Add constructor for GPUArrays (#166)
Browse files Browse the repository at this point in the history
* Add constructor for GPUArrays

* Use recursive_eltype

* Add constructor for GPUComponentArrays

* Dont't convert eltype on GPU

* Use SciML definition of parameterless_type
  • Loading branch information
ldeso authored Oct 6, 2022
1 parent a7d6f7d commit e991612
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/compat/gpuarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,10 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
GPUArrays.mapreducedim!(f, $(op), getdata(r), getdata(A); init=neutral_element($(op), T))
end
end

function ComponentArray(nt::NamedTuple{names,<:Tuple{Vararg{Union{GPUArrays.AbstractGPUArray,GPUComponentArray}}}}) where {names}
T = recursive_eltype(nt)
gpuarray = getdata(first(nt))
G = Base.typename(typeof(gpuarray)).wrapper # SciMLBase.parameterless_type(gpuarray)
return GPUArrays.adapt(G, ComponentArray(NamedTuple{names}(map(GPUArrays.adapt(Array{T}), nt))))
end

0 comments on commit e991612

Please sign in to comment.