Skip to content

Avoid searching unnecessary dirs for shared libs #1801

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 57 additions & 85 deletions transformer_engine/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,26 @@

"""FW agnostic user-end APIs"""

import sys
import glob
import sysconfig
import subprocess
import ctypes
import functools
import glob
import importlib
from importlib.metadata import version, metadata, PackageNotFoundError
import logging
import os
import platform
import importlib
import functools
from pathlib import Path
from importlib.metadata import version, metadata, PackageNotFoundError
import platform
import subprocess
import sys
import sysconfig
from typing import Optional


_logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def _is_pip_package_installed(package):
def _is_pip_package_installed(package) -> bool:
"""Check if the given package is installed via pip."""

# This is needed because we only want to return true
Expand All @@ -37,21 +38,21 @@ def _is_pip_package_installed(package):


@functools.lru_cache(maxsize=None)
def _find_shared_object_in_te_dir(te_path: Path, prefix: str):
def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]:
"""
Find a shared object file of given prefix in the top level TE directory.
Only the following locations are searched to avoid stray SOs and build
artifacts:
1. The given top level directory (editable install).
2. `transformer_engine` named directories (source install).
3. `wheel_lib` named directories (PyPI install).
Find a shared object file with the given prefix within the top level TE directory.

The following locations are searched:
1. Top level directory (editable install).
2. `transformer_engine` directory (source install).
3. `wheel_lib` directory (PyPI install).

Returns None if no shared object files are found.
Raises an error if multiple shared object files are found.
"""

# Ensure top level dir exists and has the module. before searching.
if not te_path.exists() or not (te_path / "transformer_engine").exists():
# Ensure top level dir exists and has the module before searching.
if not te_path.is_dir() or not (te_path / "transformer_engine").exists():
return None

files = []
Expand All @@ -63,11 +64,12 @@ def _find_shared_object_in_te_dir(te_path: Path, prefix: str):
)

# Search.
for dirname, _, names in os.walk(te_path):
if Path(dirname) in search_paths:
for name in names:
if name.startswith(prefix) and name.endswith(f".{_get_sys_extension()}"):
files.append(Path(dirname, name))
for dir_path in search_paths:
if not dir_path.is_dir():
continue
for file_path in dir_path.iterdir():
if file_path.name.startswith(prefix) and file_path.suffix == _get_sys_extension():
files.append(file_path)

if len(files) == 0:
return None
Expand All @@ -79,16 +81,12 @@ def _find_shared_object_in_te_dir(te_path: Path, prefix: str):
@functools.lru_cache(maxsize=None)
def _get_shared_object_file(library: str) -> Path:
"""
Return the path of the shared object file for the given TE
library, one of 'core', 'torch', or 'jax'.

Several factors affect finding the correct location of the shared object:
1. System and environment.
2. If the installation is from source or via PyPI.
- Source installed .sos are placed in top level dir
- Wheel/PyPI installed .sos are placed in 'wheel_lib' dir to avoid conflicts.
3. For source installations, is the install editable/inplace?
4. The user directory from where TE is being imported.
Path to shared object file for a Transformer Engine library.

TE libraries are 'core', 'torch', or 'jax'. This function first
searches in the imported TE directory, and then in the
site-packages directory.

"""

# Check provided input and determine the correct prefix for .so.
Expand All @@ -98,47 +96,23 @@ def _get_shared_object_file(library: str) -> Path:
else:
so_prefix = f"transformer_engine_{library}"

# Check TE install location (will be local if TE is available in current dir for import).
te_install_dir = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
so_path_in_install_dir = _find_shared_object_in_te_dir(te_install_dir, so_prefix)
# Search for shared lib in imported directory
te_path = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
so_path = _find_shared_object_in_te_dir(te_path, so_prefix)
if so_path is not None:
return so_path

# Check default python package install location in system.
site_packages_dir = Path(sysconfig.get_paths()["purelib"])
so_path_in_default_dir = _find_shared_object_in_te_dir(site_packages_dir, so_prefix)

# Case 1: Typical user workflow: Both locations are the same, return any result.
if te_install_dir == site_packages_dir:
assert (
so_path_in_install_dir is not None
), f"Could not find shared object file for Transformer Engine {library} lib."
return so_path_in_install_dir

# Case 2: ERR! Both locations are different but returned a valid result.
# NOTE: Unlike for source installations, pip does not wipe out artifacts from
# editable builds. In case developers are executing inside a TE directory via
# an inplace build, and then move to a regular build, the local shared object
# file will be incorrectly picked up without the following logic.
if so_path_in_install_dir is not None and so_path_in_default_dir is not None:
raise RuntimeError(
f"Found multiple shared object files: {so_path_in_install_dir} and"
f" {so_path_in_default_dir}. Remove local shared objects installed"
f" here {so_path_in_install_dir} or change the working directory to"
"execute from outside TE."
)

# Case 3: Typical dev workflow: Editable install
if so_path_in_install_dir is not None:
return so_path_in_install_dir

# Case 4: Executing from inside a TE directory without an inplace build available.
if so_path_in_default_dir is not None:
return so_path_in_default_dir
# Search for shared lib in site-packages directory
te_path = Path(sysconfig.get_paths()["purelib"])
so_path = _find_shared_object_in_te_dir(te_path, so_prefix)
if so_path is not None:
return so_path

raise RuntimeError(f"Could not find shared object file for Transformer Engine {library} lib.")


@functools.lru_cache(maxsize=None)
def load_framework_extension(framework: str):
def load_framework_extension(framework: str) -> None:
"""
Load shared library with Transformer Engine framework bindings
and check verify correctness if installed via PyPI.
Expand Down Expand Up @@ -196,18 +170,16 @@ def load_framework_extension(framework: str):


@functools.lru_cache(maxsize=None)
def _get_sys_extension():
def _get_sys_extension() -> str:
"""File extension for shared objects."""
system = platform.system()
if system == "Linux":
extension = "so"
elif system == "Darwin":
extension = "dylib"
elif system == "Windows":
extension = "dll"
else:
raise RuntimeError(f"Unsupported operating system ({system})")

return extension
return ".so"
if system == "Darwin":
return ".dylib"
if system == "Windows":
return ".dll"
raise RuntimeError(f"Unsupported operating system ({system})")


@functools.lru_cache(maxsize=None)
Expand All @@ -221,7 +193,7 @@ def _load_nvidia_cuda_library(lib_name: str):
so_paths = glob.glob(
os.path.join(
sysconfig.get_path("purelib"),
f"nvidia/{lib_name}/lib/lib*.{_get_sys_extension()}.*[0-9]",
f"nvidia/{lib_name}/lib/lib*{_get_sys_extension()}.*[0-9]",
)
)

Expand All @@ -236,7 +208,7 @@ def _load_nvidia_cuda_library(lib_name: str):


@functools.lru_cache(maxsize=None)
def _nvidia_cudart_include_dir():
def _nvidia_cudart_include_dir() -> str:
"""Returns the include directory for cuda_runtime.h if exists in python environment."""

try:
Expand All @@ -255,14 +227,14 @@ def _load_cudnn():
# Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
if cudnn_home:
libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
libs = glob.glob(f"{cudnn_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

# Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
libs = glob.glob(f"{cuda_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
Expand All @@ -273,15 +245,15 @@ def _load_cudnn():
return handle

# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)


@functools.lru_cache(maxsize=None)
def _load_nvrtc():
"""Load NVRTC shared library."""
# Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True)
libs = glob.glob(f"{cuda_home}/**/libnvrtc{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
libs.sort(reverse=True, key=os.path.basename)
if libs:
Expand All @@ -305,7 +277,7 @@ def _load_nvrtc():
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)

# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)


@functools.lru_cache(maxsize=None)
Expand Down