Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Python SharedMemory as the backend #24

Merged
merged 2 commits into from
Mar 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 19 additions & 24 deletions pshmem/shmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,19 @@
##

import sys
from multiprocessing import shared_memory

import numpy as np
import sysv_ipc

from .utils import mpi_data_type, random_shm_key
from .utils import (
mpi_data_type,
random_shm_key,
remove_shm_from_resource_tracker,
)

# Monkey patch resource_tracker. Remove once upstream CPython
# changes are merged.
remove_shm_from_resource_tracker()


class MPIShared(object):
Expand Down Expand Up @@ -149,7 +157,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
if self._rank == 0:
# Get a random 64bit integer between the supported range of keys
self._shm_index = random_shm_key()
# Name, just used for printing
# Name, used as global tag.
self._name = f"MPIShared_{self._shm_index}"
if self._comm is not None:
self._shm_index = self._comm.bcast(self._shm_index, root=0)
Expand Down Expand Up @@ -177,10 +185,8 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
# First rank on each node creates the buffer
if self._noderank == 0:
try:
self._shmem = sysv_ipc.SharedMemory(
self._shm_index,
flags=sysv_ipc.IPC_CREX,
size=int(nbytes),
self._shmem = shared_memory.SharedMemory(
name=self._name, create=True, size=int(nbytes),
)
except Exception as e:
msg = "Process {}: {}".format(self._rank, self._name)
Expand All @@ -199,8 +205,8 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
# Other ranks on the node attach
if self._noderank != 0:
try:
self._shmem = sysv_ipc.SharedMemory(
self._shm_index, flags=0, size=0
self._shmem = shared_memory.SharedMemory(
name=self._name, create=False, size=int(nbytes)
)
except Exception as e:
msg = "Process {}: {}".format(self._rank, self._name)
Expand All @@ -216,7 +222,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
self._flat = np.ndarray(
self._n,
dtype=self._dtype,
buffer=self._shmem,
buffer=self._shmem.buf,
)

# Initialize to zero.
Expand All @@ -230,19 +236,6 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
if self._nodecomm is not None:
self._nodecomm.barrier()

# Now the rank zero process will call remove() to mark the shared
# memory segment for removal. However, this will not actually
# be removed until all processes detach.
if self._noderank == 0:
try:
self._shmem.remove()
except sysv_ipc.ExistentialError:
msg = "Process {}: {}".format(self._rank, self._name)
msg += " failed to remove shared memory"
msg += ": {}".format(e)
print(msg, flush=True)
raise

def __del__(self):
self.close()

Expand Down Expand Up @@ -370,7 +363,9 @@ def close(self):
del self._flat
if hasattr(self, "_shmem"):
if self._shmem is not None:
self._shmem.detach()
self._shmem.close()
if self._noderank == 0:
self._shmem.unlink()
del self._shmem
self._shmem = None
self._flat = None
Expand Down
7 changes: 6 additions & 1 deletion pshmem/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,9 +432,14 @@ def test_zero(self):
# dims = (200, 1000000)
# dt = np.float64
# shm = MPIShared(dims, dt, self.comm)
# if self.comm is None or self.comm.rank == 0:
# temp = np.ones(dims, dtype=dt)
# else:
# temp = None
# shm.set(temp, fromrank=0)
# del temp
# import time
# time.sleep(60)
# shm.close()
# del shm
# return

Expand Down
32 changes: 28 additions & 4 deletions pshmem/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
##

import random
import sys
# Import for monkey patching resource tracker
from multiprocessing import resource_tracker

import numpy as np
import sysv_ipc


def mpi_data_type(comm, dt):
Expand Down Expand Up @@ -48,7 +50,7 @@ def mpi_data_type(comm, dt):


def random_shm_key():
"""Get a random 64bit integer in the range supported by shmget()
"""Get a random positive integer for using in shared memory naming.

The python random library is used, and seeded with the default source
(either system time or os.urandom).
Expand All @@ -57,8 +59,30 @@ def random_shm_key():
(int): The random integer.

"""
min_val = sysv_ipc.KEY_MIN
max_val = sysv_ipc.KEY_MAX
min_val = 0
max_val = sys.maxsize
# Seed with default source of randomness
random.seed(a=None)
return random.randint(min_val, max_val)


def remove_shm_from_resource_tracker():
"""Monkey-patch multiprocessing.resource_tracker so SharedMemory won't be tracked

More details at: https://bugs.python.org/issue38119
"""

def fix_register(name, rtype):
if rtype == "shared_memory":
return
return resource_tracker._resource_tracker.register(self, name, rtype)
resource_tracker.register = fix_register

def fix_unregister(name, rtype):
if rtype == "shared_memory":
return
return resource_tracker._resource_tracker.unregister(self, name, rtype)
resource_tracker.unregister = fix_unregister

if "shared_memory" in resource_tracker._CLEANUP_FUNCS:
del resource_tracker._CLEANUP_FUNCS["shared_memory"]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def readme():
scripts=None,
license="BSD",
python_requires=">=3.8.0",
install_requires=["numpy", "sysv_ipc"],
install_requires=["numpy"],
extras_require={"mpi": ["mpi4py>=3.0"]},
cmdclass=versioneer.get_cmdclass(),
classifiers=[
Expand Down