|
| 1 | +module Defs |
| 2 | + |
| 3 | +using PartitionedArrays |
| 4 | +using PetscCall |
| 5 | +using LinearAlgebra |
| 6 | +using Test |
| 7 | + |
| 8 | +function spmv_petsc!(b,A,x) |
| 9 | + # Convert the input to petsc objects |
| 10 | + mat = Ref{PetscCall.Mat}() |
| 11 | + vec_b = Ref{PetscCall.Vec}() |
| 12 | + vec_x = Ref{PetscCall.Vec}() |
| 13 | + parts = linear_indices(partition(x)) |
| 14 | + petsc_comm = PetscCall.setup_petsc_comm(parts) |
| 15 | + args_A = PetscCall.MatCreateMPIAIJWithSplitArrays_args(A,petsc_comm) |
| 16 | + args_b = PetscCall.VecCreateMPIWithArray_args(copy(b),petsc_comm) |
| 17 | + args_x = PetscCall.VecCreateMPIWithArray_args(copy(x),petsc_comm) |
| 18 | + ownership = (args_A,args_b,args_x) |
| 19 | + PetscCall.@check_error_code PetscCall.MatCreateMPIAIJWithSplitArrays(args_A...,mat) |
| 20 | + PetscCall.@check_error_code PetscCall.MatAssemblyBegin(mat[],PetscCall.MAT_FINAL_ASSEMBLY) |
| 21 | + PetscCall.@check_error_code PetscCall.MatAssemblyEnd(mat[],PetscCall.MAT_FINAL_ASSEMBLY) |
| 22 | + PetscCall.@check_error_code PetscCall.VecCreateMPIWithArray(args_b...,vec_b) |
| 23 | + PetscCall.@check_error_code PetscCall.VecCreateMPIWithArray(args_x...,vec_x) |
| 24 | + # This line does the actual product |
| 25 | + PetscCall.@check_error_code PetscCall.MatMult(mat[],vec_x[],vec_b[]) |
| 26 | + # Move the result back to julia |
| 27 | + PetscCall.VecCreateMPIWithArray_args_reversed!(b,args_b) |
| 28 | + # Cleanup |
| 29 | + GC.@preserve ownership PetscCall.@check_error_code PetscCall.MatDestroy(mat) |
| 30 | + GC.@preserve ownership PetscCall.@check_error_code PetscCall.VecDestroy(vec_b) |
| 31 | + GC.@preserve ownership PetscCall.@check_error_code PetscCall.VecDestroy(vec_x) |
| 32 | + b |
| 33 | +end |
| 34 | + |
| 35 | +function test_spmm_petsc(A,B) |
| 36 | + parts = linear_indices(partition(A)) |
| 37 | + petsc_comm = PetscCall.setup_petsc_comm(parts) |
| 38 | + C1, cacheC = spmm(A,B,reuse=true) |
| 39 | + mat_A = Ref{PetscCall.Mat}() |
| 40 | + mat_B = Ref{PetscCall.Mat}() |
| 41 | + mat_C = Ref{PetscCall.Mat}() |
| 42 | + args_A = PetscCall.MatCreateMPIAIJWithSplitArrays_args(A,petsc_comm) |
| 43 | + args_B = PetscCall.MatCreateMPIAIJWithSplitArrays_args(B,petsc_comm) |
| 44 | + ownership = (args_A,args_B) |
| 45 | + PetscCall.@check_error_code PetscCall.MatCreateMPIAIJWithSplitArrays(args_A...,mat_A) |
| 46 | + PetscCall.@check_error_code PetscCall.MatCreateMPIAIJWithSplitArrays(args_B...,mat_B) |
| 47 | + PetscCall.@check_error_code PetscCall.MatProductCreate(mat_A[],mat_B[],C_NULL,mat_C) |
| 48 | + PetscCall.@check_error_code PetscCall.MatProductSetType(mat_C[],PetscCall.MATPRODUCT_AB) |
| 49 | + PetscCall.@check_error_code PetscCall.MatProductSetFromOptions(mat_C[]) |
| 50 | + PetscCall.@check_error_code PetscCall.MatProductSymbolic(mat_C[]) |
| 51 | + PetscCall.@check_error_code PetscCall.MatProductNumeric(mat_C[]) |
| 52 | + PetscCall.@check_error_code PetscCall.MatProductReplaceMats(mat_A[],mat_B[],C_NULL,mat_C[]) |
| 53 | + PetscCall.@check_error_code PetscCall.MatProductNumeric(mat_C[]) |
| 54 | + PetscCall.@check_error_code PetscCall.MatProductClear(mat_C[]) |
| 55 | + GC.@preserve ownership PetscCall.@check_error_code PetscCall.MatDestroy(mat_A) |
| 56 | + GC.@preserve ownership PetscCall.@check_error_code PetscCall.MatDestroy(mat_B) |
| 57 | + GC.@preserve ownership PetscCall.@check_error_code PetscCall.MatDestroy(mat_C) |
| 58 | +end |
| 59 | + |
| 60 | +function main(distribute,params) |
| 61 | + nodes_per_dir = params.nodes_per_dir |
| 62 | + parts_per_dir = params.parts_per_dir |
| 63 | + np = prod(parts_per_dir) |
| 64 | + ranks = LinearIndices((np,)) |> distribute |
| 65 | + A = PartitionedArrays.laplace_matrix(nodes_per_dir,parts_per_dir,ranks) |
| 66 | + rows = partition(axes(A,1)) |
| 67 | + cols = partition(axes(A,2)) |
| 68 | + x = pones(cols) |
| 69 | + b1 = pzeros(rows) |
| 70 | + b2 = pzeros(rows) |
| 71 | + mul!(b1,A,x) |
| 72 | + if ! PetscCall.initialized() |
| 73 | + PetscCall.init() |
| 74 | + end |
| 75 | + spmv_petsc!(b2,A,x) |
| 76 | + c = b1-b2 |
| 77 | + tol = 1.0e-12 |
| 78 | + @test norm(b1) > tol |
| 79 | + @test norm(b2) > tol |
| 80 | + @test norm(c)/norm(b1) < tol |
| 81 | + B = 2*A |
| 82 | + test_spmm_petsc(A,B) |
| 83 | +end |
| 84 | + |
| 85 | +end #module |
0 commit comments