Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rocm6.4_internal_testing] Enable wheels (#1884) #1907

Merged
46 changes: 0 additions & 46 deletions .circleci/scripts/binary_populate_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -71,52 +71,6 @@ fi

export PYTORCH_BUILD_NUMBER=1

# Set triton version as part of PYTORCH_EXTRA_INSTALL_REQUIREMENTS
TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt)

# Here PYTORCH_EXTRA_INSTALL_REQUIREMENTS is already set for the all the wheel builds hence append TRITON_CONSTRAINT
TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64' and python_version != '3.13t'"
if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then
TRITON_REQUIREMENT="triton==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"
# Only linux Python <= 3.13, not 3.13t are supported wheels for triton
if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then
TRITON_SHORTHASH=$(cut -c1-8 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt)
TRITON_REQUIREMENT="pytorch-triton==${TRITON_VERSION}+git${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"
fi
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${TRITON_REQUIREMENT}"
fi

# This part is done in the builder scripts so commenting the duplicate code
: <<'BLOCK_COMMENT'
# Set triton via PYTORCH_EXTRA_INSTALL_REQUIREMENTS for triton rocm package
if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*rocm.* && $(uname) == "Linux" ]]; then
TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"
if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then
TRITON_SHORTHASH=$(cut -c1-8 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt)
TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}+git${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"
fi
if [[ -z "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${TRITON_REQUIREMENT}"
else
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${TRITON_REQUIREMENT}"
fi
fi
BLOCK_COMMENT

# Set triton via PYTORCH_EXTRA_INSTALL_REQUIREMENTS for triton xpu package
if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*xpu.* && $(uname) == "Linux" ]]; then
TRITON_REQUIREMENT="pytorch-triton-xpu==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"
if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then
TRITON_SHORTHASH=$(cut -c1-8 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-xpu.txt)
TRITON_REQUIREMENT="pytorch-triton-xpu==${TRITON_VERSION}+git${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"
fi
if [[ -z "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${TRITON_REQUIREMENT}"
else
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${TRITON_REQUIREMENT}"
fi
fi

USE_GLOO_WITH_OPENSSL="ON"
if [[ "$GPU_ARCH_TYPE" =~ .*aarch64.* ]]; then
USE_GLOO_WITH_OPENSSL="OFF"
Expand Down
7 changes: 3 additions & 4 deletions .github/scripts/build_triton_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_rocm_version() -> str:
rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h"
if not os.path.isfile(rocm_version_h):
rocm_version_h = f"{rocm_path}/include/rocm_version.h"

# The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
if os.path.isfile(rocm_version_h):
RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)")
Expand Down Expand Up @@ -94,7 +94,7 @@ def build_triton(
# Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8
# while release build should only include the version, i.e. 2.1.0
rocm_version = get_rocm_version()
version_suffix = f"+rocm{rocm_version}_{commit_hash[:10]}"
version_suffix = f"+rocm{rocm_version}.git{commit_hash[:8]}"
version += version_suffix

with TemporaryDirectory() as tmpdir:
Expand Down Expand Up @@ -168,6 +168,7 @@ def build_triton(

# change built wheel name and version
env["TRITON_WHEEL_NAME"] = triton_pkg_name
env["TRITON_WHEEL_VERSION_SUFFIX"] = version_suffix
if with_clang_ldd:
env["TRITON_BUILD_WITH_CLANG_LLD"] = "1"

Expand All @@ -183,8 +184,6 @@ def build_triton(
cwd=triton_basedir,
shell=True,
)
cur_rocm_ver = get_rocm_version()
check_call(["scripts/amd/setup_rocm_libs.sh", cur_rocm_ver], cwd=triton_basedir)
print("ROCm libraries setup for triton installation...")

check_call(
Expand Down
6 changes: 3 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ set(CMAKE_C_STANDARD
# ---[ Utils
include(cmake/public/utils.cmake)

# --- [ Check that minimal gcc version is 9.3+
if(CMAKE_COMPILER_IS_GNUCXX AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.3)
# --- [ Check that minimal gcc version is 9.2+
if(CMAKE_COMPILER_IS_GNUCXX AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.2)
message(
FATAL_ERROR
"GCC-9.3 or newer is required to compile PyTorch, but found ${CMAKE_CXX_COMPILER_VERSION}"
"GCC-9.2 or newer is required to compile PyTorch, but found ${CMAKE_CXX_COMPILER_VERSION}"
)
endif()

Expand Down