From fdcea9b7ec6e647e42224e606b89708b0394177d Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 1 Nov 2024 13:27:25 -0700 Subject: [PATCH] Fix shared memory permission issue in a shared pod environment (#813) --- streaming/base/dataset.py | 2 +- streaming/base/shared/prefix.py | 70 +++++++++++++++++++++++++-------- streaming/base/util.py | 9 +++-- tests/test_shared.py | 33 ++++++++++++++++ 4 files changed, 93 insertions(+), 21 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 774503682..2ce9f6e48 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -528,7 +528,7 @@ def __init__(self, ] self._shm_prefix_int, self._locals_shm = get_shm_prefix(streams_local, streams_remote, self._unique_rank_world) - self._filelock_root = os.path.join(gettempdir(), 'streaming') + self._filelock_root = gettempdir() os.makedirs(self._filelock_root, exist_ok=True) # Create the shared memory-backed barrier, without its lock, which is unpickleable. diff --git a/streaming/base/shared/prefix.py b/streaming/base/shared/prefix.py index 56f2cc6f9..b25743df2 100644 --- a/streaming/base/shared/prefix.py +++ b/streaming/base/shared/prefix.py @@ -7,14 +7,16 @@ prevent shared resources like shared memory from colliding. """ +import os from collections import Counter +from tempfile import gettempdir from time import sleep from typing import Iterator, Union import numpy as np from torch import distributed as dist -from streaming.base.constant import LOCALS, TICK +from streaming.base.constant import BARRIER_FILELOCK, CACHE_FILELOCK, LOCALS, SHM_TO_CLEAN, TICK from streaming.base.shared import SharedMemory from streaming.base.world import World @@ -91,7 +93,8 @@ def _check_self(streams_local: list[str]) -> None: f'Reused local directory: {duplicate_local_dirs}. Provide a different one.') -def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, None]]) -> int: +def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, None]], + shm_name: str) -> int: """Find the next available prefix while checking existing local dirs for overlap. Local leader walks the existing shm prefixes starting from zero, verifying that there is no @@ -101,18 +104,40 @@ def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, No Args: streams_local (List[str]): Our local working directories. streams_remote (List[Union[str, None]]): Our remote working directories. + shm_name (str): The shared memory file name, e.g., LOCALS, BARRIER etc. Returns: int: Next available prefix int. """ prefix_int = 0 + for prefix_int in _each_prefix_int(): - name = _get_path(prefix_int, LOCALS) + + name = _get_path(prefix_int, shm_name) + + # Check if any shared memory filelocks exist for the current prefix + try: + filelock_exists = any( + os.path.exists(os.path.join(gettempdir(), _get_path(prefix_int, filelock_name))) + for filelock_name in [BARRIER_FILELOCK, CACHE_FILELOCK]) + if filelock_exists: + continue + except PermissionError: + continue + + # Attempt to access shared memory by name. Use prefix_int if files do not exist try: shm = SharedMemory(name, False) + except PermissionError: + continue except FileNotFoundError: break + + if shm_name != LOCALS: + continue + their_locals, _ = _unpack_locals(bytes(shm.buf)) + # Do not check for a conflicting local directories across existing shared memory if # remote directories are None. Get the next prefix. if any(streams_remote): @@ -135,7 +160,7 @@ def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, No def _check_and_find_retrying(streams_local: list[str], streams_remote: list[Union[str, None]], - retry: int) -> int: + shm_name: str, retry: int) -> int: """Find the next available prefix while checking existing dirs for overlap. If an overlap is found, sleeps for a tick and then tries again, up to "retry" times. We allow @@ -145,6 +170,7 @@ def _check_and_find_retrying(streams_local: list[str], streams_remote: list[Unio Args: streams_local (List[str]): Our local working directories. streams_remote (List[Union[str, None]]): Our remote working directories. + shm_name (str): The shared memory file name, e.g., LOCALS, BARRIER etc. retry (int): Number of retries upon failure before raising an exception. Returns: @@ -155,7 +181,7 @@ def _check_and_find_retrying(streams_local: list[str], streams_remote: list[Unio errs = [] for _ in range(1 + retry): try: - return _check_and_find(streams_local, streams_remote) + return _check_and_find(streams_local, streams_remote, shm_name) except ValueError as err: errs.append(err) sleep(TICK) @@ -184,9 +210,16 @@ def get_shm_prefix(streams_local: list[str], # Check my locals for overlap. _check_self(streams_local) + prefix_int = max([ + _check_and_find_retrying(streams_local, streams_remote, shm_name=shm_name, retry=retry) + for shm_name in SHM_TO_CLEAN + ]) + + if dist.is_available() and dist.is_initialized(): + dist.barrier() + # First, the local leader registers the first available shm prefix, recording its locals. if world.is_local_leader: - prefix_int = _check_and_find_retrying(streams_local, streams_remote, retry) name = _get_path(prefix_int, LOCALS) data = _pack_locals(streams_local, prefix_int) shm = SharedMemory(name, True, len(data)) @@ -197,15 +230,18 @@ def get_shm_prefix(streams_local: list[str], # Non-local leaders go next, searching for match. if not world.is_local_leader: - for prefix_int in _each_prefix_int(): - name = _get_path(prefix_int, LOCALS) - try: - shm = SharedMemory(name, False) - except FileNotFoundError: - raise RuntimeError(f'Internal error: shared memory prefix was not registered by ' + - f'local leader') - their_locals, their_prefix_int = _unpack_locals(bytes(shm.buf)) - if streams_local == their_locals and prefix_int == their_prefix_int: - break - + name = _get_path(prefix_int, LOCALS) + try: + shm = SharedMemory(name, False) + except FileNotFoundError: + raise RuntimeError(f'Internal error: shared memory prefix={prefix_int} was not ' + + f'registered by local leader. This may be because you specified ' + + f'different ``local`` parameters from different ranks.') + + their_locals, their_prefix_int = _unpack_locals(bytes(shm.buf)) + if streams_local != their_locals or prefix_int != their_prefix_int: + raise RuntimeError(f'Internal error: shared memory registered does not match ' + + f'local leader as streams_local or prefix_int not match. ' + + f'local leader: {their_locals} and {their_prefix_int}. ' + + f'expected: {streams_local} and {prefix_int}.') return prefix_int, shm # pyright: ignore diff --git a/streaming/base/util.py b/streaming/base/util.py index f802e644d..68ea1c93e 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -184,9 +184,12 @@ def clean_stale_shared_memory() -> None: try: shm = BuiltinSharedMemory(name, True, 4) except FileExistsError: - shm = BuiltinSharedMemory(name, False, 4) - leaked_shm = True - finally: + try: + shm = BuiltinSharedMemory(name, False, 4) + leaked_shm = True + except PermissionError: + continue + if shm: shm.close() # pyright: ignore shm.unlink() # Come out of loop if no leaked shared memory diff --git a/tests/test_shared.py b/tests/test_shared.py index f5016f1eb..d1914617d 100644 --- a/tests/test_shared.py +++ b/tests/test_shared.py @@ -1,13 +1,19 @@ # Copyright 2022-2024 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 +import os +import tempfile from unittest.mock import MagicMock, patch import numpy as np import pytest from streaming.base import StreamingDataset +from streaming.base.constant import LOCALS from streaming.base.shared import SharedArray, get_shm_prefix +from streaming.base.shared.memory import SharedMemory +from streaming.base.shared.prefix import _check_and_find +from streaming.base.util import clean_stale_shared_memory from streaming.base.world import World from tests.common.utils import convert_to_mds @@ -157,3 +163,30 @@ def test_shared_array_size_is_integer(mock_shared_memory: MagicMock, dtype: type mock_shared_memory.assert_called_once() # pyright: ignore size_arg = mock_shared_memory.call_args[1]['size'] assert isinstance(size_arg, int), 'Size passed to SharedMemory is not an integer' + + +def test_check_and_find_skips_filelock_conflict(): + """Test _check_and_find skips prefix due to file lock conflict.""" + clean_stale_shared_memory() + + with patch('os.path.exists') as mock_exists, \ + patch('multiprocessing.shared_memory.SharedMemory', side_effect=FileNotFoundError): + # Simulate that `/000000.barrier_filelock` exists, indicating a lock conflict + bf_path = os.path.join(tempfile.gettempdir(), '000000_barrier_filelock') + mock_exists.side_effect = lambda path: path == bf_path + + # Expect _check_and_find to return 1 as the next available prefix + next_prefix = _check_and_find(['local_dir'], [None], LOCALS) + assert next_prefix == 1 + + +@patch.object(SharedMemory, + '__init__', + side_effect=[ + PermissionError('Mocked permission error'), + FileNotFoundError('Mocked file not found error') + ]) +def test_shared_memory_permission_error(mock_shared_memory_class: MagicMock): + with patch('os.path.exists', return_value=False): + next_prefix = _check_and_find(['local'], [None], LOCALS) + assert next_prefix == 1