diff --git a/pshmem/shmem.py b/pshmem/shmem.py index 383da34..41ffb57 100644 --- a/pshmem/shmem.py +++ b/pshmem/shmem.py @@ -212,7 +212,21 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None): print(msg, flush=True) raise - # Wait for other processes to attach + # Create a numpy array which acts as a view of the buffer. + self._flat = np.ndarray( + self._n, + dtype=self._dtype, + buffer=self._shmem, + ) + + # Initialize to zero. + if self._noderank == 0: + self._flat[:] = 0 + + # Wrap + self.data = self._flat.reshape(self._shape) + + # Wait for other processes to attach and wrap if self._nodecomm is not None: self._nodecomm.barrier() @@ -229,19 +243,6 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None): print(msg, flush=True) raise - # Create a numpy array which acts as a view of the buffer. - self._flat = np.ndarray( - self._n, - dtype=self._dtype, - buffer=self._shmem, - ) - # Initialize to zero. - if self._noderank == 0: - self._flat[:] = 0 - - # Wrap - self.data = self._flat.reshape(self._shape) - def __del__(self): self.close() diff --git a/pshmem/test.py b/pshmem/test.py index ffdf4e0..1797376 100644 --- a/pshmem/test.py +++ b/pshmem/test.py @@ -467,9 +467,25 @@ def test_lock(self): def run(): + comm = None + if MPI is not None: + comm = MPI.COMM_WORLD + suite = unittest.TestSuite() suite.addTest(unittest.makeSuite(LockTest)) suite.addTest(unittest.makeSuite(ShmemTest)) runner = unittest.TextTestRunner() - runner.run(suite) + + ret = 0 + _ret = runner.run(suite) + if not _ret.wasSuccessful(): + ret += 1 + + if comm is not None: + ret = comm.allreduce(ret, op=MPI.SUM) + + if ret > 0: + print(f"{ret} Processes had failures") + sys.exit(6) + return