diff --git a/pshmem/locking.py b/pshmem/locking.py index 1e512e1..148d793 100644 --- a/pshmem/locking.py +++ b/pshmem/locking.py @@ -1,5 +1,5 @@ ## -# Copyright (c) 2017-2020, all rights reserved. Use of this source code +# Copyright (c) 2017-2024, all rights reserved. Use of this source code # is governed by a BSD license that can be found in the top-level # LICENSE file. ## diff --git a/pshmem/shmem.py b/pshmem/shmem.py index fbc97fd..383da34 100644 --- a/pshmem/shmem.py +++ b/pshmem/shmem.py @@ -1,17 +1,15 @@ ## -# Copyright (c) 2017-2020, all rights reserved. Use of this source code +# Copyright (c) 2017-2024, all rights reserved. Use of this source code # is governed by a BSD license that can be found in the top-level # LICENSE file. ## import sys -import mmap -import uuid import numpy as np -import posix_ipc +import sysv_ipc -from .utils import mpi_data_type +from .utils import mpi_data_type, random_shm_key class MPIShared(object): @@ -147,16 +145,19 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None): # and a unique random ID. 
self._name = None + self._shm_index = None if self._rank == 0: - rng_str = uuid.uuid4().hex[:12] - self._name = f"MPIShared_{rng_str}" + # Get a random 64bit integer between the supported range of keys + self._shm_index = random_shm_key() + # Name, just used for printing + self._name = f"MPIShared_{self._shm_index}" if self._comm is not None: + self._shm_index = self._comm.bcast(self._shm_index, root=0) self._name = self._comm.bcast(self._name, root=0) # Only allocate our buffers if the total number of elements is > 0 self._shmem = None - self._shmap = None self._flat = None self.data = None @@ -176,9 +177,9 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None): # First rank on each node creates the buffer if self._noderank == 0: try: - self._shmem = posix_ipc.SharedMemory( - self._name, - posix_ipc.O_CREX, + self._shmem = sysv_ipc.SharedMemory( + self._shm_index, + flags=sysv_ipc.IPC_CREX, size=int(nbytes), ) except Exception as e: @@ -190,27 +191,6 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None): msg += ": {}".format(e) print(msg, flush=True) raise - try: - # MMap the shared memory - self._shmap = mmap.mmap( - self._shmem.fd, - self._shmem.size, - ) - except Exception as e: - msg = "Process {}: {}".format(self._rank, self._name) - msg += " failed MMap of {} bytes".format(nbytes) - msg += " ({} elements of {} bytes each)".format( - self._n, self._dsize - ) - msg += ": {}".format(e) - print(msg, flush=True) - # Try to free the shared memory object - try: - self._shmem.close_fd() - self._shmem.unlink() - except Exception as eclose: - pass - raise # Wait for that to be created if self._nodecomm is not None: @@ -219,11 +199,8 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None): # Other ranks on the node attach if self._noderank != 0: try: - self._shmem = posix_ipc.SharedMemory(self._name) - # MMap the shared memory - self._shmap = mmap.mmap( - self._shmem.fd, - self._shmem.size, + 
self._shmem = sysv_ipc.SharedMemory( + self._shm_index, flags=0, size=0 ) except Exception as e: msg = "Process {}: {}".format(self._rank, self._name) @@ -239,22 +216,15 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None): if self._nodecomm is not None: self._nodecomm.barrier() - # Now that all processes have mmap'ed the shared memory we can - # close the shared memory handle - self._shmem.close_fd() - - # Wait for all processes to close file handle - if self._nodecomm is not None: - self._nodecomm.barrier() - - # One process requests the file to be deleted, but this will not - # actually happen until all processes release their mmap. + # Now the rank zero process will call remove() to mark the shared + # memory segment for removal. However, this will not actually + # be removed until all processes detach. if self._noderank == 0: try: - self._shmem.unlink() - except posix_ipc.ExistentialError: + self._shmem.remove() + except sysv_ipc.ExistentialError as e: msg = "Process {}: {}".format(self._rank, self._name) - msg += " failed to unlink shared memory" + msg += " failed to remove shared memory" msg += ": {}".format(e) print(msg, flush=True) raise @@ -263,7 +233,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None): self._flat = np.ndarray( self._n, dtype=self._dtype, - buffer=self._shmap, + buffer=self._shmem, ) # Initialize to zero.
if self._noderank == 0: @@ -272,8 +242,6 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None): # Wrap self.data = self._flat.reshape(self._shape) - - def __del__(self): self.close() @@ -399,17 +367,11 @@ def close(self): del self.data if hasattr(self, "_flat"): del self._flat - if hasattr(self, "_shmap"): - # Close the mmap'ed memory - if self._shmap is not None: - self._shmap.close() - del self._shmap - self._shmap = None if hasattr(self, "_shmem"): if self._shmem is not None: + self._shmem.detach() del self._shmem self._shmem = None - self._flat = None self.data = None diff --git a/pshmem/test.py b/pshmem/test.py index 2c814df..ffdf4e0 100644 --- a/pshmem/test.py +++ b/pshmem/test.py @@ -1,5 +1,5 @@ ## -# Copyright (c) 2017-2020, all rights reserved. Use of this source code +# Copyright (c) 2017-2024, all rights reserved. Use of this source code # is governed by a BSD license that can be found in the top-level # LICENSE file. ## @@ -425,6 +425,19 @@ def test_zero(self): except RuntimeError: print("successful raise with no data during set()", flush=True) + # def test_hang(self): + # # Run this while monitoring memory usage (e.g. with htop) and then + # # do kill -9 on one of the processes to verify that the kernel + # # releases shared memory. + # dims = (200, 1000000) + # dt = np.float64 + # shm = MPIShared(dims, dt, self.comm) + # import time + # time.sleep(60) + # shm.close() + # del shm + # return + class LockTest(unittest.TestCase): def setUp(self): diff --git a/pshmem/utils.py b/pshmem/utils.py index c527c3d..a999a9c 100644 --- a/pshmem/utils.py +++ b/pshmem/utils.py @@ -1,10 +1,13 @@ ## -# Copyright (c) 2017-2020, all rights reserved. Use of this source code +# Copyright (c) 2017-2024, all rights reserved. Use of this source code # is governed by a BSD license that can be found in the top-level # LICENSE file. 
## +import random + import numpy as np +import sysv_ipc def mpi_data_type(comm, dt): @@ -42,3 +45,25 @@ def mpi_data_type(comm, dt): raise dsize = mpitype.Get_size() return (dsize, mpitype) + + +def random_shm_key(): + """Get a random integer key in the range accepted by sysv_ipc. + + The value is drawn from a private random.SystemRandom generator, so + the state of the global python random generator is left untouched. + The reserved key IPC_PRIVATE is never returned, since processes + cannot attach to such a segment by key. + + Returns: + (int): The random integer. + + """ + min_val = sysv_ipc.KEY_MIN + max_val = sysv_ipc.KEY_MAX + # Private generator: do not reseed the global one used by callers + rng = random.SystemRandom() + key = rng.randint(min_val, max_val) + while key == sysv_ipc.IPC_PRIVATE: + key = rng.randint(min_val, max_val) + return key diff --git a/setup.py b/setup.py index b581372..7a54cd6 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ def readme(): scripts=None, license="BSD", python_requires=">=3.8.0", - install_requires=["numpy", "posix_ipc"], + install_requires=["numpy", "sysv_ipc"], extras_require={"mpi": ["mpi4py>=3.0"]}, cmdclass=versioneer.get_cmdclass(), classifiers=[