From f557f85930b391e71be530c27c8df38afaf51ffc Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 20 May 2025 00:31:52 +0000 Subject: [PATCH] Avoid searching unnecessary dirs for shared libs Signed-off-by: Tim Moon --- transformer_engine/common/__init__.py | 142 +++++++++++--------------- 1 file changed, 57 insertions(+), 85 deletions(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 835a74389b..8b89c54d9a 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -4,25 +4,26 @@ """FW agnostic user-end APIs""" -import sys -import glob -import sysconfig -import subprocess import ctypes +import functools +import glob +import importlib +from importlib.metadata import version, metadata, PackageNotFoundError import logging import os -import platform -import importlib -import functools from pathlib import Path -from importlib.metadata import version, metadata, PackageNotFoundError +import platform +import subprocess +import sys +import sysconfig +from typing import Optional _logger = logging.getLogger(__name__) @functools.lru_cache(maxsize=None) -def _is_pip_package_installed(package): +def _is_pip_package_installed(package) -> bool: """Check if the given package is installed via pip.""" # This is needed because we only want to return true @@ -37,21 +38,21 @@ def _is_pip_package_installed(package): @functools.lru_cache(maxsize=None) -def _find_shared_object_in_te_dir(te_path: Path, prefix: str): +def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]: """ - Find a shared object file of given prefix in the top level TE directory. - Only the following locations are searched to avoid stray SOs and build - artifacts: - 1. The given top level directory (editable install). - 2. `transformer_engine` named directories (source install). - 3. `wheel_lib` named directories (PyPI install). + Find a shared object file with the given prefix within the top level TE directory. + + The following locations are searched: + 1. Top level directory (editable install). + 2. `transformer_engine` directory (source install). + 3. `wheel_lib` directory (PyPI install). Returns None if no shared object files are found. Raises an error if multiple shared object files are found. """ - # Ensure top level dir exists and has the module. before searching. - if not te_path.exists() or not (te_path / "transformer_engine").exists(): + # Ensure top level dir exists and has the module before searching. + if not te_path.is_dir() or not (te_path / "transformer_engine").exists(): return None files = [] @@ -63,11 +64,12 @@ def _find_shared_object_in_te_dir(te_path: Path, prefix: str): ) # Search. - for dirname, _, names in os.walk(te_path): - if Path(dirname) in search_paths: - for name in names: - if name.startswith(prefix) and name.endswith(f".{_get_sys_extension()}"): - files.append(Path(dirname, name)) + for dir_path in search_paths: + if not dir_path.is_dir(): + continue + for file_path in dir_path.iterdir(): + if file_path.name.startswith(prefix) and file_path.suffix == _get_sys_extension(): + files.append(file_path) if len(files) == 0: return None @@ -79,16 +81,12 @@ def _find_shared_object_in_te_dir(te_path: Path, prefix: str): @functools.lru_cache(maxsize=None) def _get_shared_object_file(library: str) -> Path: """ - Return the path of the shared object file for the given TE - library, one of 'core', 'torch', or 'jax'. - - Several factors affect finding the correct location of the shared object: - 1. System and environment. - 2. If the installation is from source or via PyPI. - - Source installed .sos are placed in top level dir - - Wheel/PyPI installed .sos are placed in 'wheel_lib' dir to avoid conflicts. - 3. For source installations, is the install editable/inplace? - 4. The user directory from where TE is being imported. + Path to shared object file for a Transformer Engine library. + + TE libraries are 'core', 'torch', or 'jax'. This function first + searches in the imported TE directory, and then in the + site-packages directory. + """ # Check provided input and determine the correct prefix for .so. @@ -98,47 +96,23 @@ def _get_shared_object_file(library: str) -> Path: else: so_prefix = f"transformer_engine_{library}" - # Check TE install location (will be local if TE is available in current dir for import). - te_install_dir = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent - so_path_in_install_dir = _find_shared_object_in_te_dir(te_install_dir, so_prefix) + # Search for shared lib in imported directory + te_path = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent + so_path = _find_shared_object_in_te_dir(te_path, so_prefix) + if so_path is not None: + return so_path - # Check default python package install location in system. - site_packages_dir = Path(sysconfig.get_paths()["purelib"]) - so_path_in_default_dir = _find_shared_object_in_te_dir(site_packages_dir, so_prefix) - - # Case 1: Typical user workflow: Both locations are the same, return any result. - if te_install_dir == site_packages_dir: - assert ( - so_path_in_install_dir is not None - ), f"Could not find shared object file for Transformer Engine {library} lib." - return so_path_in_install_dir - - # Case 2: ERR! Both locations are different but returned a valid result. - # NOTE: Unlike for source installations, pip does not wipe out artifacts from - # editable builds. In case developers are executing inside a TE directory via - # an inplace build, and then move to a regular build, the local shared object - # file will be incorrectly picked up without the following logic. - if so_path_in_install_dir is not None and so_path_in_default_dir is not None: - raise RuntimeError( - f"Found multiple shared object files: {so_path_in_install_dir} and" - f" {so_path_in_default_dir}. Remove local shared objects installed" - f" here {so_path_in_install_dir} or change the working directory to" - "execute from outside TE." - ) - - # Case 3: Typical dev workflow: Editable install - if so_path_in_install_dir is not None: - return so_path_in_install_dir - - # Case 4: Executing from inside a TE directory without an inplace build available. - if so_path_in_default_dir is not None: - return so_path_in_default_dir + # Search for shared lib in site-packages directory + te_path = Path(sysconfig.get_paths()["purelib"]) + so_path = _find_shared_object_in_te_dir(te_path, so_prefix) + if so_path is not None: + return so_path raise RuntimeError(f"Could not find shared object file for Transformer Engine {library} lib.") @functools.lru_cache(maxsize=None) -def load_framework_extension(framework: str): +def load_framework_extension(framework: str) -> None: """ Load shared library with Transformer Engine framework bindings and check verify correctness if installed via PyPI. @@ -196,18 +170,16 @@ def load_framework_extension(framework: str): @functools.lru_cache(maxsize=None) -def _get_sys_extension(): +def _get_sys_extension() -> str: + """File extension for shared objects.""" system = platform.system() if system == "Linux": - extension = "so" - elif system == "Darwin": - extension = "dylib" - elif system == "Windows": - extension = "dll" - else: - raise RuntimeError(f"Unsupported operating system ({system})") - - return extension + return ".so" + if system == "Darwin": + return ".dylib" + if system == "Windows": + return ".dll" + raise RuntimeError(f"Unsupported operating system ({system})") @functools.lru_cache(maxsize=None) @@ -221,7 +193,7 @@ def _load_nvidia_cuda_library(lib_name: str): so_paths = glob.glob( os.path.join( sysconfig.get_path("purelib"), - f"nvidia/{lib_name}/lib/lib*.{_get_sys_extension()}.*[0-9]", + f"nvidia/{lib_name}/lib/lib*{_get_sys_extension()}.*[0-9]", ) ) @@ -236,7 +208,7 @@ def _load_nvidia_cuda_library(lib_name: str): @functools.lru_cache(maxsize=None) -def _nvidia_cudart_include_dir(): +def _nvidia_cudart_include_dir() -> str: """Returns the include directory for cuda_runtime.h if exists in python environment.""" try: @@ -255,14 +227,14 @@ def _load_cudnn(): # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH") if cudnn_home: - libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) + libs = glob.glob(f"{cudnn_home}/**/libcudnn{_get_sys_extension()}*", recursive=True) libs.sort(reverse=True, key=os.path.basename) if libs: return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) # Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" - libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) + libs = glob.glob(f"{cuda_home}/**/libcudnn{_get_sys_extension()}*", recursive=True) libs.sort(reverse=True, key=os.path.basename) if libs: return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) @@ -273,7 +245,7 @@ def _load_cudnn(): return handle # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise - return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) + return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) @functools.lru_cache(maxsize=None) @@ -281,7 +253,7 @@ def _load_nvrtc(): """Load NVRTC shared library.""" # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" - libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True) + libs = glob.glob(f"{cuda_home}/**/libnvrtc{_get_sys_extension()}*", recursive=True) libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs)) libs.sort(reverse=True, key=os.path.basename) if libs: @@ -305,7 +277,7 @@ def _load_nvrtc(): return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise - return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) + return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) @functools.lru_cache(maxsize=None)