Skip to content

Commit dae2543

Browse files
committed
testing mat multiplication routines
1 parent 37d6521 commit dae2543

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

test/mpi_array/api_test_defs.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,31 @@ function spmv_petsc!(b,A,x)
3232
b
3333
end
3434

35+
function test_spmm_petsc(A,B)
36+
parts = linear_indices(partition(A))
37+
petsc_comm = PetscCall.setup_petsc_comm(parts)
38+
C1, cacheC = spmm(A,B,reuse=true)
39+
mat_A = Ref{PetscCall.Mat}()
40+
mat_B = Ref{PetscCall.Mat}()
41+
mat_C = Ref{PetscCall.Mat}()
42+
args_A = PetscCall.MatCreateMPIAIJWithSplitArrays_args(A,petsc_comm)
43+
args_B = PetscCall.MatCreateMPIAIJWithSplitArrays_args(B,petsc_comm)
44+
ownership = (args_A,args_B)
45+
PetscCall.@check_error_code PetscCall.MatCreateMPIAIJWithSplitArrays(args_A...,mat_A)
46+
PetscCall.@check_error_code PetscCall.MatCreateMPIAIJWithSplitArrays(args_B...,mat_B)
47+
PetscCall.@check_error_code PetscCall.MatProductCreate(mat_A[],mat_B[],C_NULL,mat_C)
48+
PetscCall.@check_error_code PetscCall.MatProductSetType(mat_C[],PetscCall.MATPRODUCT_AB)
49+
PetscCall.@check_error_code PetscCall.MatProductSetFromOptions(mat_C[])
50+
PetscCall.@check_error_code PetscCall.MatProductSymbolic(mat_C[])
51+
PetscCall.@check_error_code PetscCall.MatProductNumeric(mat_C[])
52+
PetscCall.@check_error_code PetscCall.MatProductReplaceMats(mat_A[],mat_B[],C_NULL,mat_C[])
53+
PetscCall.@check_error_code PetscCall.MatProductNumeric(mat_C[])
54+
PetscCall.@check_error_code PetscCall.MatProductClear(mat_C[])
55+
GC.@preserve ownership PetscCall.@check_error_code PetscCall.MatDestroy(mat_A)
56+
GC.@preserve ownership PetscCall.@check_error_code PetscCall.MatDestroy(mat_B)
57+
GC.@preserve ownership PetscCall.@check_error_code PetscCall.MatDestroy(mat_C)
58+
end
59+
3560
function main(distribute,params)
3661
nodes_per_dir = params.nodes_per_dir
3762
parts_per_dir = params.parts_per_dir
@@ -53,6 +78,8 @@ function main(distribute,params)
5378
@test norm(b1) > tol
5479
@test norm(b2) > tol
5580
@test norm(c)/norm(b1) < tol
81+
B = 2*A
82+
test_spmm_petsc(A,B)
5683
end
5784

5885
end #module

0 commit comments

Comments
 (0)