Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to mpi f08 in all files #60

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/mpifx_abort.fpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
!> Contains wrapper for \c MPI_ABORT.
module mpifx_abort_module
use mpi
use mpi_f08, only : mpi_abort
use mpifx_comm_module, only : mpifx_comm
use mpifx_helper_module, only : handle_errorflag
implicit none
Expand Down Expand Up @@ -47,7 +47,7 @@ contains
errorcode0 = -1
end if

call mpi_abort(mycomm%id, errorcode0, error0)
call mpi_abort(mycomm%comm, errorcode0, error0)
call handle_errorflag(error0, "MPI_ABORT in mpifx_abort", error)

end subroutine mpifx_abort
Expand Down
7 changes: 4 additions & 3 deletions lib/mpifx_allgather.fpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

!> Contains wrapper for \c MPI_ALLGATHER
module mpifx_allgather_module
use mpi
use mpi_f08, only : mpi_allgather, mpi_character, mpi_complex, mpi_double_complex,&
& mpi_double_precision, mpi_integer, mpi_logical, mpi_real
use mpifx_comm_module, only : mpifx_comm
use mpifx_helper_module, only : dp, handle_errorflag, sp
implicit none
Expand Down Expand Up @@ -122,7 +123,7 @@ contains
@:ASSERT(size(recv) == ${SIZE}$ * mycomm%size)
@:ASSERT(size(recv, dim=${RANK}$) == size(send, dim=${RANK}$) * mycomm%size)

call mpi_allgather(send, ${COUNT}$, ${MPITYPE}$, recv, ${COUNT}$, ${MPITYPE}$, mycomm%id,&
call mpi_allgather(send, ${COUNT}$, ${MPITYPE}$, recv, ${COUNT}$, ${MPITYPE}$, mycomm%comm,&
& error0)
call handle_errorflag(error0, 'MPI_ALLGATHER in mpifx_allgather_${SUFFIX}$', error)

Expand Down Expand Up @@ -162,7 +163,7 @@ contains
@:ASSERT(size(recv, dim=${RANK + 1}$) == mycomm%size)

call mpi_allgather(send, ${COUNT}$, ${MPITYPE}$, recv, ${COUNT}$, ${MPITYPE}$,&
& mycomm%id, error0)
& mycomm%comm, error0)
call handle_errorflag(error0, 'MPI_ALLGATHER in mpifx_allgather_${SUFFIX}$', error)

end subroutine mpifx_allgather_${SUFFIX}$
Expand Down
7 changes: 4 additions & 3 deletions lib/mpifx_allgatherv.fpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
end if

call mpi_allgatherv(send, size(send), ${MPI_TYPE}$, recv, recvcounts, displs0, &
& ${MPI_TYPE}$, mycomm%id, error0)
& ${MPI_TYPE}$, mycomm%comm, error0)

call handle_errorflag(error0, "MPI_ALLGATHERV in mpifx_allgatherv_${SUFFIX}$", error)

Expand Down Expand Up @@ -96,7 +96,7 @@
end if

call mpi_allgatherv(send, ${SEND_BUFFER_SIZE}$, ${MPI_TYPE}$, recv, recvcounts, displs0, &
& ${MPI_TYPE}$, mycomm%id, error0)
& ${MPI_TYPE}$, mycomm%comm, error0)

call handle_errorflag(error0, "MPI_ALLGATHERV in mpifx_allgatherv_${SUFFIX}$", error)

Expand All @@ -106,7 +106,8 @@

!> Contains wrapper for \c MPI_allgatherv
module mpifx_allgatherv_module
use mpi
use mpi_f08, only : mpi_allgatherv, mpi_character, mpi_complex, mpi_double_complex,&
& mpi_double_precision, mpi_integer, mpi_logical, mpi_real
use mpifx_comm_module, only : mpifx_comm
use mpifx_helper_module, only : dp, handle_errorflag, sp
implicit none
Expand Down
73 changes: 62 additions & 11 deletions lib/mpifx_allreduce.fpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

!> Contains wrapper for \c MPI_ALLREDUCE.
module mpifx_allreduce_module
use mpi
use mpi_f08, only : mpi_allreduce, mpi_complex, mpi_double_complex, mpi_double_precision,&
& mpi_in_place, mpi_integer, mpi_logical, mpi_op, mpi_real
use mpifx_comm_module, only : mpifx_comm
use mpifx_helper_module, only : dp, handle_errorflag, sp
implicit none
Expand Down Expand Up @@ -50,7 +51,8 @@ module mpifx_allreduce_module
interface mpifx_allreduce
#:for TYPE in TYPES
#:for RANK in RANKS
module procedure mpifx_allreduce_${TYPE_ABBREVS[TYPE]}$${RANK}$
module procedure mpifx_allreduce_with_type${TYPE_ABBREVS[TYPE]}$${RANK}$
module procedure mpifx_allreduce_with_id${TYPE_ABBREVS[TYPE]}$${RANK}$
#:endfor
#:endfor
end interface mpifx_allreduce
Expand Down Expand Up @@ -94,7 +96,8 @@ module mpifx_allreduce_module
interface mpifx_allreduceip
#:for TYPE in TYPES
#:for RANK in RANKS
module procedure mpifx_allreduceip_${TYPE_ABBREVS[TYPE]}$${RANK}$
module procedure mpifx_allreduceip_with_type${TYPE_ABBREVS[TYPE]}$${RANK}$
module procedure mpifx_allreduceip_with_id${TYPE_ABBREVS[TYPE]}$${RANK}$
#:endfor
#:endfor
end interface mpifx_allreduceip
Expand All @@ -109,7 +112,7 @@ contains
!!
!! See MPI documentation (mpi_allreduce()) for further details.
!!
subroutine mpifx_allreduce_${SUFFIX}$(mycomm, orig, reduced, reductionop, error)
subroutine mpifx_allreduce_with_type${SUFFIX}$(mycomm, orig, reduced, reductionop, error)

!> MPI communicator.
type(mpifx_comm), intent(in) :: mycomm
Expand All @@ -121,7 +124,7 @@ contains
${TYPE}$, intent(inout) :: reduced${RANKSUFFIX(RANK)}$

!> Reduction operator
integer, intent(in) :: reductionop
type(mpi_op), intent(in) :: reductionop

!> Error code on exit.
integer, intent(out), optional :: error
Expand All @@ -135,11 +138,36 @@ contains
#:set SIZE = '1' if RANK == 0 else 'size(orig)'
#:set COUNT = SIZE

call mpi_allreduce(orig, reduced, ${COUNT}$, ${MPITYPE}$, reductionop, mycomm%id, error0)
call mpi_allreduce(orig, reduced, ${COUNT}$, ${MPITYPE}$, reductionop, mycomm%comm, error0)
call handle_errorflag(error0, 'MPI_ALLREDUCE in mpifx_allreduce_${SUFFIX}$', error)

end subroutine mpifx_allreduce_${SUFFIX}$
end subroutine mpifx_allreduce_with_type${SUFFIX}$


subroutine mpifx_allreduce_with_id${SUFFIX}$(mycomm, orig, reduced, reductionop, error)

!> MPI communicator.
type(mpifx_comm), intent(in) :: mycomm

!> Quantity to be reduced.
${TYPE}$, intent(in) :: orig${RANKSUFFIX(RANK)}$

!> Contains result on exit.
${TYPE}$, intent(inout) :: reduced${RANKSUFFIX(RANK)}$

!> Reduction operator
integer, intent(in) :: reductionop

!> Error code on exit.
integer, intent(out), optional :: error

type(mpi_op) :: newop

newop%mpi_val = reductionop

call mpifx_allreduce(mycomm, orig, reduced, newop, error)

end subroutine mpifx_allreduce_with_id${SUFFIX}$
#:enddef mpifx_allreduce_template


Expand All @@ -151,7 +179,7 @@ contains
!!
!! See MPI documentation (mpi_allreduce()) for further details.
!!
subroutine mpifx_allreduceip_${SUFFIX}$(mycomm, origreduced, reductionop, error)
subroutine mpifx_allreduceip_with_type${SUFFIX}$(mycomm, origreduced, reductionop, error)

!> MPI communicator.
type(mpifx_comm), intent(in) :: mycomm
Expand All @@ -160,7 +188,7 @@ contains
${TYPE}$, intent(inout) :: origreduced${RANKSUFFIX(RANK)}$

!> Reduction operator.
integer, intent(in) :: reductionop
type(mpi_op), intent(in) :: reductionop

!> Error code on exit.
integer, intent(out), optional :: error
Expand All @@ -170,11 +198,34 @@ contains
#:set SIZE = '1' if RANK == 0 else 'size(origreduced)'
#:set COUNT = SIZE

call mpi_allreduce(MPI_IN_PLACE, origreduced, ${COUNT}$, ${MPITYPE}$, reductionop, mycomm%id,&
call mpi_allreduce(MPI_IN_PLACE, origreduced, ${COUNT}$, ${MPITYPE}$, reductionop, mycomm%comm,&
& error0)
call handle_errorflag(error0, "MPI_REDUCE in mpifx_allreduceip_${SUFFIX}$", error)

end subroutine mpifx_allreduceip_${SUFFIX}$
end subroutine mpifx_allreduceip_with_type${SUFFIX}$

subroutine mpifx_allreduceip_with_id${SUFFIX}$(mycomm, origreduced, reductionop, error)

!> MPI communicator.
type(mpifx_comm), intent(in) :: mycomm

!> Quantity to be reduced on input, reduced on exit.
${TYPE}$, intent(inout) :: origreduced${RANKSUFFIX(RANK)}$

!> Reduction operator.
integer, intent(in) :: reductionop

!> Error code on exit.
integer, intent(out), optional :: error

type(mpi_op) :: newop

newop%mpi_val = reductionop

call mpifx_allreduceip(mycomm, origreduced, newop, error)

end subroutine mpifx_allreduceip_with_id${SUFFIX}$


#:enddef mpifx_allreduceip_template

Expand Down
4 changes: 2 additions & 2 deletions lib/mpifx_barrier.fpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

!> Contains wrapper for \c MPI_BARRIER.
module mpifx_barrier_module
use mpi
use mpi_f08, only : mpi_barrier
use mpifx_comm_module, only : mpifx_comm
use mpifx_helper_module, only : handle_errorflag
implicit none
Expand Down Expand Up @@ -40,7 +40,7 @@ contains

integer :: error0

call mpi_barrier(mycomm%id, error0)
call mpi_barrier(mycomm%comm, error0)
call handle_errorflag(error0, "MPI_BARRIER in mpifx_barrier", error)

end subroutine mpifx_barrier
Expand Down
5 changes: 3 additions & 2 deletions lib/mpifx_bcast.fpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

!> Contains wrapper for \c MPI_BCAST.
module mpifx_bcast_module
use mpi
use mpi_f08, only : mpi_bcast, mpi_character, mpi_complex, mpi_double_complex,&
& mpi_double_precision, mpi_integer, mpi_logical, mpi_real
use mpifx_comm_module, only : mpifx_comm
use mpifx_helper_module, only : dp, getoptarg, handle_errorflag, sp
implicit none
Expand Down Expand Up @@ -75,7 +76,7 @@ contains
#:set COUNT = ('len(msg) * ' + SIZE if HASLENGTH else SIZE)

call getoptarg(mycomm%leadrank, root0, root)
call mpi_bcast(msg, ${COUNT}$, ${MPITYPE}$, root0, mycomm%id, error0)
call mpi_bcast(msg, ${COUNT}$, ${MPITYPE}$, root0, mycomm%comm, error0)
call handle_errorflag(error0, "MPI_BCAST in mpifx_bcast_${SUFFIX}$", error)

end subroutine mpifx_bcast_${SUFFIX}$
Expand Down
68 changes: 48 additions & 20 deletions lib/mpifx_comm.fpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
!> Contains the extended MPI communicator.
module mpifx_comm_module
use mpi
use mpi_f08, only : mpi_comm, mpi_comm_world, mpi_info_null
use mpifx_helper_module, only : getoptarg, handle_errorflag
implicit none
private
Expand All @@ -9,14 +9,19 @@ module mpifx_comm_module

!> MPI communicator with some additional information.
type mpifx_comm
integer :: id !< Communicator id.
integer :: size !< Nr. of processes (size).
integer :: rank !< Rank of the current process.
integer :: leadrank !< Index of the lead node.
logical :: lead !< True if current process is the lead (rank == 0).
integer :: id !< Communicator id.
type(mpi_comm) :: comm !< MPI communicator handle.
integer :: size !< Nr. of processes (size).
integer :: rank !< Rank of the current process.
integer :: leadrank !< Index of the lead node.
logical :: lead !< True if current process is the lead (rank == 0).
contains

!> Initializes the MPI environment.
procedure :: init => mpifx_comm_init
procedure, private :: mpifx_comm_init_int
procedure, private :: mpifx_comm_init_comm

generic :: init => mpifx_comm_init_int, mpifx_comm_init_comm

!> Creates a new communicator by splitting the old one.
procedure :: split => mpifx_comm_split
Expand All @@ -38,28 +43,48 @@ contains
!! \param error Error flag on return containing the first error occuring
!! during the calls mpi_comm_size and mpi_comm_rank.
!!
subroutine mpifx_comm_init(self, commid, error)
subroutine mpifx_comm_init_comm(self, comm, error)
class(mpifx_comm), intent(out) :: self
integer, intent(in), optional :: commid
type(mpi_comm), intent(in), optional :: comm
integer, intent(out), optional :: error

integer :: error0
type(mpi_comm) :: default_comm
default_comm = MPI_COMM_WORLD

call getoptarg(MPI_COMM_WORLD, self%id, commid)
call mpi_comm_size(self%id, self%size, error0)
if (present(comm)) then
self%comm = comm
else
self%comm = default_comm
end if
self%id = self%comm%mpi_val
call mpi_comm_size(self%comm, self%size, error0)
call handle_errorflag(error0, "mpi_comm_size() in mpifx_comm_init()", error)
if (error0 /= 0) then
return
end if
call mpi_comm_rank(self%id, self%rank, error0)
call mpi_comm_rank(self%comm, self%rank, error0)
call handle_errorflag(error0, "mpi_comm_rank() in mpifx_comm_init()", error)
if (error0 /= 0) then
return
end if
self%leadrank = 0
self%lead = (self%rank == self%leadrank)

end subroutine mpifx_comm_init
end subroutine mpifx_comm_init_comm


subroutine mpifx_comm_init_int(self, commid, error)
class(mpifx_comm), intent(out) :: self
integer, intent(in) :: commid
integer, intent(out), optional :: error

type(mpi_comm) :: newcomm

newcomm%mpi_val = commid
call self%mpifx_comm_init_comm(newcomm, error)

end subroutine mpifx_comm_init_int


!> Creates a new communicators by splitting the old one.
Expand Down Expand Up @@ -102,14 +127,15 @@ contains
class(mpifx_comm), intent(out) :: newcomm
integer, intent(out), optional :: error

integer :: error0, newcommid
integer :: error0
type(mpi_comm) :: newmpicomm

call mpi_comm_split(self%id, splitkey, rankkey, newcommid, error0)
call mpi_comm_split(self%comm, splitkey, rankkey, newmpicomm, error0)
call handle_errorflag(error0, "mpi_comm_split() in mpifx_comm_split()", error)
if (error0 /= 0) then
return
end if
call newcomm%init(newcommid, error)
call newcomm%init(newmpicomm, error)

end subroutine mpifx_comm_split

Expand Down Expand Up @@ -150,14 +176,15 @@ contains
class(mpifx_comm), intent(out) :: newcomm
integer, intent(out), optional :: error

integer :: error0, newcommid
integer :: error0
type(mpi_comm) :: newmpicomm

call mpi_comm_split_type(self%id, splittype, rankkey, MPI_INFO_NULL, newcommid, error0)
call mpi_comm_split_type(self%comm, splittype, rankkey, MPI_INFO_NULL, newmpicomm, error0)
call handle_errorflag(error0, "mpi_comm_split_type() in mpifx_comm_split_type()", error)
if (error0 /= 0) then
return
end if
call newcomm%init(newcommid, error)
call newcomm%init(newmpicomm, error)

end subroutine mpifx_comm_split_type

Expand All @@ -173,7 +200,8 @@ contains

integer :: error

call mpi_comm_free(self%id, error)
call mpi_comm_free(self%comm, error)
self%id = self%comm%mpi_val

end subroutine mpifx_comm_free

Expand Down
2 changes: 1 addition & 1 deletion lib/mpifx_constants.fpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
!> Exports some MPI constants.
!! \cond HIDDEN
module mpifx_constants_module
use mpi
use mpi_f08, only : MPI_ADDRESS_KIND
private

public :: MPI_MAX, MPI_MIN, MPI_SUM, MPI_PROD
Expand Down
Loading