Skip to content

Commit 44b5f4c

Browse files
authored
Merge pull request #8 from fverdugo/api_tests
Api tests
2 parents f7f25ca + dae2543 commit 44b5f4c

File tree

6 files changed

+326
-144
lines changed

6 files changed

+326
-144
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1"
1515

1616
[compat]
1717
MPI = "0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20"
18-
PETSc_jll = "=3.13.4, =3.15.2"
18+
PETSc_jll = "3"
1919
PartitionedArrays = "0.4"
2020
Preferences = "1"
2121
SparseArrays = "1"

src/api.jl

Lines changed: 200 additions & 141 deletions
Large diffs are not rendered by default.

test/mpi_array/api_test.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
module ApiTests
2+
3+
using Test
4+
using MPI
5+
using PartitionedArrays
6+
7+
repodir = normpath(joinpath(@__DIR__,"..",".."))
8+
9+
defs = joinpath(repodir,"test","mpi_array","api_test_defs.jl")
10+
11+
include(defs)
12+
params = (;nodes_per_dir=(10,10,10),parts_per_dir=(1,1,1))
13+
with_mpi(dist->Defs.main(dist,params))
14+
15+
code = quote
16+
using MPI; MPI.Init()
17+
using PartitionedArrays
18+
include($defs)
19+
params = (;nodes_per_dir=(10,10,10),parts_per_dir=(2,2,2))
20+
with_mpi(dist->Defs.main(dist,params))
21+
end
22+
run(`$(mpiexec()) -np 8 $(Base.julia_cmd()) --project=$repodir -e $code`)
23+
24+
code = quote
25+
using MPI; MPI.Init()
26+
using PartitionedArrays
27+
include($defs)
28+
params = (;nodes_per_dir=(10,10,10),parts_per_dir=(2,4,1))
29+
with_mpi(dist->Defs.main(dist,params))
30+
end
31+
run(`$(mpiexec()) -np 8 $(Base.julia_cmd()) --project=$repodir -e $code`)
32+
33+
end # module

test/mpi_array/api_test_defs.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
module Defs
2+
3+
using PartitionedArrays
4+
using PetscCall
5+
using LinearAlgebra
6+
using Test
7+
8+
function spmv_petsc!(b,A,x)
9+
# Convert the input to petsc objects
10+
mat = Ref{PetscCall.Mat}()
11+
vec_b = Ref{PetscCall.Vec}()
12+
vec_x = Ref{PetscCall.Vec}()
13+
parts = linear_indices(partition(x))
14+
petsc_comm = PetscCall.setup_petsc_comm(parts)
15+
args_A = PetscCall.MatCreateMPIAIJWithSplitArrays_args(A,petsc_comm)
16+
args_b = PetscCall.VecCreateMPIWithArray_args(copy(b),petsc_comm)
17+
args_x = PetscCall.VecCreateMPIWithArray_args(copy(x),petsc_comm)
18+
ownership = (args_A,args_b,args_x)
19+
PetscCall.@check_error_code PetscCall.MatCreateMPIAIJWithSplitArrays(args_A...,mat)
20+
PetscCall.@check_error_code PetscCall.MatAssemblyBegin(mat[],PetscCall.MAT_FINAL_ASSEMBLY)
21+
PetscCall.@check_error_code PetscCall.MatAssemblyEnd(mat[],PetscCall.MAT_FINAL_ASSEMBLY)
22+
PetscCall.@check_error_code PetscCall.VecCreateMPIWithArray(args_b...,vec_b)
23+
PetscCall.@check_error_code PetscCall.VecCreateMPIWithArray(args_x...,vec_x)
24+
# This line does the actual product
25+
PetscCall.@check_error_code PetscCall.MatMult(mat[],vec_x[],vec_b[])
26+
# Move the result back to julia
27+
PetscCall.VecCreateMPIWithArray_args_reversed!(b,args_b)
28+
# Cleanup
29+
GC.@preserve ownership PetscCall.@check_error_code PetscCall.MatDestroy(mat)
30+
GC.@preserve ownership PetscCall.@check_error_code PetscCall.VecDestroy(vec_b)
31+
GC.@preserve ownership PetscCall.@check_error_code PetscCall.VecDestroy(vec_x)
32+
b
33+
end
34+
35+
function test_spmm_petsc(A,B)
36+
parts = linear_indices(partition(A))
37+
petsc_comm = PetscCall.setup_petsc_comm(parts)
38+
C1, cacheC = spmm(A,B,reuse=true)
39+
mat_A = Ref{PetscCall.Mat}()
40+
mat_B = Ref{PetscCall.Mat}()
41+
mat_C = Ref{PetscCall.Mat}()
42+
args_A = PetscCall.MatCreateMPIAIJWithSplitArrays_args(A,petsc_comm)
43+
args_B = PetscCall.MatCreateMPIAIJWithSplitArrays_args(B,petsc_comm)
44+
ownership = (args_A,args_B)
45+
PetscCall.@check_error_code PetscCall.MatCreateMPIAIJWithSplitArrays(args_A...,mat_A)
46+
PetscCall.@check_error_code PetscCall.MatCreateMPIAIJWithSplitArrays(args_B...,mat_B)
47+
PetscCall.@check_error_code PetscCall.MatProductCreate(mat_A[],mat_B[],C_NULL,mat_C)
48+
PetscCall.@check_error_code PetscCall.MatProductSetType(mat_C[],PetscCall.MATPRODUCT_AB)
49+
PetscCall.@check_error_code PetscCall.MatProductSetFromOptions(mat_C[])
50+
PetscCall.@check_error_code PetscCall.MatProductSymbolic(mat_C[])
51+
PetscCall.@check_error_code PetscCall.MatProductNumeric(mat_C[])
52+
PetscCall.@check_error_code PetscCall.MatProductReplaceMats(mat_A[],mat_B[],C_NULL,mat_C[])
53+
PetscCall.@check_error_code PetscCall.MatProductNumeric(mat_C[])
54+
PetscCall.@check_error_code PetscCall.MatProductClear(mat_C[])
55+
GC.@preserve ownership PetscCall.@check_error_code PetscCall.MatDestroy(mat_A)
56+
GC.@preserve ownership PetscCall.@check_error_code PetscCall.MatDestroy(mat_B)
57+
GC.@preserve ownership PetscCall.@check_error_code PetscCall.MatDestroy(mat_C)
58+
end
59+
60+
function main(distribute,params)
61+
nodes_per_dir = params.nodes_per_dir
62+
parts_per_dir = params.parts_per_dir
63+
np = prod(parts_per_dir)
64+
ranks = LinearIndices((np,)) |> distribute
65+
A = PartitionedArrays.laplace_matrix(nodes_per_dir,parts_per_dir,ranks)
66+
rows = partition(axes(A,1))
67+
cols = partition(axes(A,2))
68+
x = pones(cols)
69+
b1 = pzeros(rows)
70+
b2 = pzeros(rows)
71+
mul!(b1,A,x)
72+
if ! PetscCall.initialized()
73+
PetscCall.init()
74+
end
75+
spmv_petsc!(b2,A,x)
76+
c = b1-b2
77+
tol = 1.0e-12
78+
@test norm(b1) > tol
79+
@test norm(b2) > tol
80+
@test norm(c)/norm(b1) < tol
81+
B = 2*A
82+
test_spmm_petsc(A,B)
83+
end
84+
85+
end #module

test/mpi_array/ksp_test.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
module KspTests
2+
13
using MPI
24
using Test
35

@@ -19,6 +21,5 @@ end
1921

2022
run(`$mpiexec_cmd -np 3 $(Base.julia_cmd()) --project=$repodir -e $code`)
2123

22-
nothing
23-
24+
end # module
2425

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ module PetscCallTest
33
using PetscCall
44
using Test
55

6+
@testset "API" begin
7+
@testset "PartitionedArrays: MPIArray" begin include("mpi_array/api_test.jl") end
8+
end
9+
610
@testset "KSP" begin
711
@testset "Sequential" begin include("ksp_test.jl") end
812
@testset "PartitionedArrays: DebugArray" begin include("debug_array/ksp_test.jl") end

0 commit comments

Comments
 (0)