@@ -32,6 +32,31 @@ function spmv_petsc!(b,A,x)
32
32
b
33
33
end
34
34
35
+ function test_spmm_petsc (A,B)
36
+ parts = linear_indices (partition (A))
37
+ petsc_comm = PetscCall. setup_petsc_comm (parts)
38
+ C1, cacheC = spmm (A,B,reuse= true )
39
+ mat_A = Ref {PetscCall.Mat} ()
40
+ mat_B = Ref {PetscCall.Mat} ()
41
+ mat_C = Ref {PetscCall.Mat} ()
42
+ args_A = PetscCall. MatCreateMPIAIJWithSplitArrays_args (A,petsc_comm)
43
+ args_B = PetscCall. MatCreateMPIAIJWithSplitArrays_args (B,petsc_comm)
44
+ ownership = (args_A,args_B)
45
+ PetscCall. @check_error_code PetscCall. MatCreateMPIAIJWithSplitArrays (args_A... ,mat_A)
46
+ PetscCall. @check_error_code PetscCall. MatCreateMPIAIJWithSplitArrays (args_B... ,mat_B)
47
+ PetscCall. @check_error_code PetscCall. MatProductCreate (mat_A[],mat_B[],C_NULL ,mat_C)
48
+ PetscCall. @check_error_code PetscCall. MatProductSetType (mat_C[],PetscCall. MATPRODUCT_AB)
49
+ PetscCall. @check_error_code PetscCall. MatProductSetFromOptions (mat_C[])
50
+ PetscCall. @check_error_code PetscCall. MatProductSymbolic (mat_C[])
51
+ PetscCall. @check_error_code PetscCall. MatProductNumeric (mat_C[])
52
+ PetscCall. @check_error_code PetscCall. MatProductReplaceMats (mat_A[],mat_B[],C_NULL ,mat_C[])
53
+ PetscCall. @check_error_code PetscCall. MatProductNumeric (mat_C[])
54
+ PetscCall. @check_error_code PetscCall. MatProductClear (mat_C[])
55
+ GC. @preserve ownership PetscCall. @check_error_code PetscCall. MatDestroy (mat_A)
56
+ GC. @preserve ownership PetscCall. @check_error_code PetscCall. MatDestroy (mat_B)
57
+ GC. @preserve ownership PetscCall. @check_error_code PetscCall. MatDestroy (mat_C)
58
+ end
59
+
35
60
function main (distribute,params)
36
61
nodes_per_dir = params. nodes_per_dir
37
62
parts_per_dir = params. parts_per_dir
@@ -53,6 +78,8 @@ function main(distribute,params)
53
78
@test norm (b1) > tol
54
79
@test norm (b2) > tol
55
80
@test norm (c)/ norm (b1) < tol
81
+ B = 2 * A
82
+ test_spmm_petsc (A,B)
56
83
end
57
84
58
85
end # module
0 commit comments