Skip to content

Commit

Permalink
Conversions of StokesArrays and ThermalArrays from GPU to CPU (#136)
Browse files Browse the repository at this point in the history
* gpu2cpu conversions

* improve Array conversions

* expand test suit

* typos & formatting

* typo
  • Loading branch information
albert-de-montserrat authored Apr 27, 2024
1 parent 2117f77 commit 8afbd4b
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/MetaJustRelax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ function environment!(model::PS_Setup{T,N}) where {T,N}
ViscoElastoPlastic,
solve!

include(joinpath(@__DIR__, "array_conversions.jl"))
export Array

include(joinpath(@__DIR__, "Utils.jl"))
export @allocate, @add, @idx, @copy
export @velocity,
Expand Down
60 changes: 60 additions & 0 deletions src/array_conversions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Device trait system

abstract type DeviceTrait end
struct CPUDeviceTrait <: DeviceTrait end
struct NonCPUDeviceTrait <: DeviceTrait end

@inline iscpu(::Array) = CPUDeviceTrait()
@inline iscpu(::AbstractArray) = NonCPUDeviceTrait()
@inline iscpu(::T) where {T} = throw(ArgumentError("Unknown device"))

@inline iscpu(::Velocity{Array{T,N}}) where {T,N} = CPUDeviceTrait()
@inline iscpu(::Velocity{AbstractArray{T,N}}) where {T,N} = NonCPUDeviceTrait()

@inline iscpu(::SymmetricTensor{Array{T,N}}) where {T,N} = CPUDeviceTrait()
@inline iscpu(::SymmetricTensor{AbstractArray{T,N}}) where {T,N} = NonCPUDeviceTrait()

@inline iscpu(::Residual{Array{T,N}}) where {T,N} = CPUDeviceTrait()
@inline iscpu(::Residual{AbstractArray{T,N}}) where {T,N} = NonCPUDeviceTrait()

@inline iscpu(::ThermalArrays{Array{T,N}}) where {T,N} = CPUDeviceTrait()
@inline iscpu(::ThermalArrays{AbstractArray{T,N}}) where {T,N} = NonCPUDeviceTrait()

@inline iscpu(::StokesArrays{M,A,B,C,Array{T,N},nDim}) where {M,A,B,C,T,N,nDim} =
CPUDeviceTrait()
@inline iscpu(::StokesArrays{M,A,B,C,AbstractArray{T,N},nDim}) where {M,A,B,C,T,N,nDim} =
NonCPUDeviceTrait()

## Conversion of structs to CPU

@inline remove_parameters(::T) where {T} = Base.typename(T).wrapper

function Array(
x::T
) where {T<:Union{StokesArrays,SymmetricTensor,ThermalArrays,Velocity,Residual}}
return Array(iscpu(x), x)
end

Array(::CPUDeviceTrait, x) = x

function Array(
::NonCPUDeviceTrait, x::T
) where {T<:Union{SymmetricTensor,ThermalArrays,Velocity,Residual}}
nfields = fieldcount(T)
cpu_fields = ntuple(Val(nfields)) do i
Base.@_inline_meta
Array(getfield(x, i))
end
T_clean = remove_parameters(x)
return T_clean(cpu_fields...)
end

function Array(::NonCPUDeviceTrait, x::StokesArrays{T,A,B,C,M,nDim}) where {T,A,B,C,M,nDim}
nfields = fieldcount(StokesArrays)
cpu_fields = ntuple(Val(nfields)) do i
Base.@_inline_meta
Array(getfield(x, i))
end
T_clean = remove_parameters(x)
return T_clean(cpu_fields...)
end
4 changes: 2 additions & 2 deletions src/stokes/MetaStokes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,10 @@ function make_stokes_struct!(; name::Symbol=:StokesArrays)
)
end

function $(name)(args::Vararg{T,N}) where {T<:AbstractArray,N}
function $(name)(args::Vararg{Any,N}) where {N}
return new{
ViscoElastic,
typeof(args[4]),
typeof(args[3]),
typeof(args[5]),
typeof(args[end]),
typeof(args[1]),
Expand Down
24 changes: 24 additions & 0 deletions test/test_arrays_conversions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using JustRelax, Test
model = PS_Setup(:Threads, Float64, 2)
environment!(model)

@testset "Array conversions" begin
ni = 10, 10
stokes = StokesArrays(ni, ViscoElastic)
thermal = ThermalArrays(ni)

@test Array(stokes.V) isa Velocity{Array{T, N}} where {T, N}
@test Array(stokes.τ) isa SymmetricTensor{Array{T, N}} where {T, N}
@test Array(stokes.R) isa Residual{Array{T, N}} where {T, N}
@test Array(stokes.P) isa Array{T, N} where {T, N}
@test Array(stokes) isa StokesArrays
@test Array(thermal) isa ThermalArrays{Array{T, N}} where {T, N}

@test JustRelax.iscpu(stokes.V) isa JustRelax.CPUDeviceTrait
@test JustRelax.iscpu(stokes.τ) isa JustRelax.CPUDeviceTrait
@test JustRelax.iscpu(stokes.R) isa JustRelax.CPUDeviceTrait
@test JustRelax.iscpu(stokes.P) isa JustRelax.CPUDeviceTrait
@test JustRelax.iscpu(stokes) isa JustRelax.CPUDeviceTrait
@test JustRelax.iscpu(thermal) isa JustRelax.CPUDeviceTrait
@test_throws ArgumentError("Unknown device") JustRelax.iscpu("potato")
end

0 comments on commit 8afbd4b

Please sign in to comment.