diff --git a/Project.toml b/Project.toml index b56f82a1..3f7be6ad 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Francesc Verdugo and contributors"] version = "0.5.10" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" CircularArrays = "7a955b69-7140-5f4e-a0ed-f168c5e2e749" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" @@ -18,6 +19,7 @@ SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] +Adapt = "4.3.0" BlockArrays = "0.16, 1" CircularArrays = "1" Distances = "0.10" diff --git a/src/PartitionedArrays.jl b/src/PartitionedArrays.jl index a402de0a..7d3e5d26 100644 --- a/src/PartitionedArrays.jl +++ b/src/PartitionedArrays.jl @@ -10,6 +10,7 @@ import MPI import IterativeSolvers import Distances using BlockArrays +using Adapt export length_to_ptrs! export rewind_ptrs! @@ -195,4 +196,5 @@ export nullspace_linear_elasticity! export near_nullspace_linear_elasticity include("gallery.jl") +include("adapt.jl") end # module diff --git a/src/adapt.jl b/src/adapt.jl new file mode 100644 index 00000000..3239951f --- /dev/null +++ b/src/adapt.jl @@ -0,0 +1,52 @@ + +function Adapt.adapt_structure(to,v::DebugArray) + v = map(v) do val + Adapt.adapt_structure(to,val) + end +end + +function Adapt.adapt_structure(to,v::MPIArray) + v = map(v) do val + Adapt.adapt_structure(to,val) + end +end + +function Adapt.adapt_structure(to,v::SplitMatrixBlocks) + own_own = Adapt.adapt(to,v.own_own) + own_ghost = Adapt.adapt(to,v.own_ghost) + ghost_ghost = Adapt.adapt(to,v.ghost_ghost) + ghost_own = Adapt.adapt(to,v.ghost_own) + split_matrix_blocks(own_own,own_ghost,ghost_own,ghost_ghost) +end + +function Adapt.adapt_structure(to,v::SplitVectorBlocks) + own = Adapt.adapt(to,v.own) + ghost = Adapt.adapt(to,v.ghost) + split_vector_blocks(own,ghost) +end + +function Adapt.adapt_structure(to,v::SplitVector) + blocks = Adapt.adapt(to,v.blocks) + perm = Adapt.adapt(to,v.permutation) + split_vector(blocks,perm) +end + +function Adapt.adapt_structure(to,v::JaggedArray) + data = Adapt.adapt_structure(to,v.data) + ptrs = Adapt.adapt_structure(to,v.ptrs) + jagged_array(data, ptrs) +end + +function Adapt.adapt_structure(to,v::SplitMatrix) + blocks = Adapt.adapt_structure(to,v.blocks) + col_per = v.col_permutation + row_per = v.row_permutation + split_matrix(blocks,row_per,col_per) +end + +function Adapt.adapt_structure(to,v::PSparseMatrix) + matrix_partition = Adapt.adapt_structure(to,v.matrix_partition) + col_par = v.col_partition + row_par = v.row_partition + PSparseMatrix(matrix_partition,row_par,col_par,v.assembled) +end diff --git a/test/adapt_tests.jl b/test/adapt_tests.jl new file mode 100644 index 00000000..df5c835f --- /dev/null +++ b/test/adapt_tests.jl @@ -0,0 +1,64 @@ +using Test +using PartitionedArrays +using Adapt + +struct FakeCuVector{A} <: AbstractVector{Float64} + vector::A +end + +Base.size(v::FakeCuVector) = size(v.vector) +Base.getindex(v::FakeCuVector,i::Integer) = v.vector[i] + +function Adapt.adapt_storage(::Type{<:FakeCuVector},x::AbstractArray) + FakeCuVector(x) +end + +function adapt_tests(distribute) + + rank = distribute(LinearIndices((2,2))) + + a = [[1,2],[3,4,5],Int[],[3,4]] + b = JaggedArray(a) + c = deepcopy(b) + + c = Adapt.adapt(FakeCuVector,c) + + @test typeof(c.data) == FakeCuVector{typeof(b.data)} + @test typeof(c.ptrs) == FakeCuVector{typeof(b.ptrs)} + @test typeof(c).name.wrapper == GenericJaggedArray + + a = [1,2,3,4,5] + b = deepcopy(a) + b = Adapt.adapt(FakeCuVector,b) + @test typeof(b) == FakeCuVector{typeof(a)} + @test b.vector == a + + own = [1,2,3,4] + ghost = [5,6,7,8] + block_a = split_vector_blocks(own, ghost) + block_b = deepcopy(block_a) + block_b = Adapt.adapt(FakeCuVector,block_b) + @test block_b.own.vector == block_a.own + @test block_b.ghost.vector == block_a.ghost + @test typeof(block_b.own) == FakeCuVector{typeof(block_a.own)} + @test typeof(block_b.ghost) == FakeCuVector{typeof(block_a.ghost)} + + + a = split_vector(block_a,[1,2,3,4,5,6,7,8]) + b = deepcopy(a) + b = Adapt.adapt(FakeCuVector,b) + + @test b.blocks.own.vector == a.blocks.own + @test b.blocks.ghost.vector == a.blocks.ghost + @test b.permutation.vector == a.permutation + + + a = distribute([[1,1,1],[2,2,2],[3,3,3],[4,4,4]]) + b = distribute([[1,1,1],[2,2,2],[3,3,3],[4,4,4]]) + b = Adapt.adapt(FakeCuVector,b) + + map(a,b) do val_a,val_b + @test typeof(val_b) == FakeCuVector{typeof(val_a)} + @test val_b.vector == val_a + end +end \ No newline at end of file diff --git a/test/debug_array/adapt_tests.jl b/test/debug_array/adapt_tests.jl new file mode 100644 index 00000000..998e367e --- /dev/null +++ b/test/debug_array/adapt_tests.jl @@ -0,0 +1,9 @@ +module DebugArrayAdaptTests + +using PartitionedArrays + +include(joinpath("..","adapt_tests.jl")) + +with_debug(adapt_tests) + +end # module diff --git a/test/debug_array/runtests.jl b/test/debug_array/runtests.jl index 2c1a61ab..b1371725 100644 --- a/test/debug_array/runtests.jl +++ b/test/debug_array/runtests.jl @@ -23,4 +23,6 @@ using PartitionedArrays @testset "fem_example" begin include("fem_example.jl") end +@testset "adapt" begin include("adapt_tests.jl") end + end #module diff --git a/test/mpi_array/adapt_tests.jl b/test/mpi_array/adapt_tests.jl new file mode 100644 index 00000000..ee834c51 --- /dev/null +++ b/test/mpi_array/adapt_tests.jl @@ -0,0 +1,5 @@ +using MPI +include("run_mpi_driver.jl") +file = joinpath(@__DIR__,"drivers","adapt_tests.jl") +run_mpi_driver(file;procs=4) + diff --git a/test/mpi_array/drivers/adapt_tests.jl b/test/mpi_array/drivers/adapt_tests.jl new file mode 100644 index 00000000..ce078041 --- /dev/null +++ b/test/mpi_array/drivers/adapt_tests.jl @@ -0,0 +1,9 @@ +module MPIArrayAdaptTests + +using PartitionedArrays + +include(joinpath("..","..","adapt_tests.jl")) + +with_mpi(adapt_tests) + +end # module diff --git a/test/mpi_array/runtests.jl b/test/mpi_array/runtests.jl index 26a3a5d3..95201c44 100644 --- a/test/mpi_array/runtests.jl +++ b/test/mpi_array/runtests.jl @@ -13,5 +13,6 @@ using PartitionedArrays @testset "p_timer_tests" begin include("p_timer_tests.jl") end @testset "fdm_example" begin include("fdm_example.jl") end @testset "fem_example" begin include("fem_example.jl") end +@testset "adapt" begin include("adapt_tests.jl") end end #module