diff --git a/rapids/rapids.sh b/rapids/rapids.sh index 877c9ad14..282bad848 100644 --- a/rapids/rapids.sh +++ b/rapids/rapids.sh @@ -8,20 +8,18 @@ function get_metadata_attribute() { /usr/share/google/get_metadata_value "attributes/${attribute_name}" || echo -n "${default_value}" } -readonly DEFAULT_DASK_RAPIDS_VERSION="22.04" +readonly DEFAULT_DASK_RAPIDS_VERSION="22.06" readonly RAPIDS_VERSION=$(get_metadata_attribute 'rapids-version' ${DEFAULT_DASK_RAPIDS_VERSION}) readonly SPARK_VERSION_ENV=$(spark-submit --version 2>&1 | sed -n 's/.*version[[:blank:]]\+\([0-9]\+\.[0-9]\).*/\1/p' | head -n1) readonly DEFAULT_SPARK_RAPIDS_VERSION="22.10.0" -if [[ "${SPARK_VERSION_ENV}" == "3"* ]]; then +if [[ "${SPARK_VERSION_ENV%%.*}" == "3" ]]; then readonly DEFAULT_CUDA_VERSION="11.5" readonly DEFAULT_CUDF_VERSION="22.10.0" readonly DEFAULT_XGBOOST_VERSION="1.6.2" readonly DEFAULT_XGBOOST_GPU_SUB_VERSION="0.3.0" - # TODO: uncomment when Spark 3.1 jars will be released - RAPIDS work with Spark 3.1, this is just for Maven URL - # readonly SPARK_VERSION="${SPARK_VERSION_ENV}" - readonly SPARK_VERSION="3.0" + readonly SPARK_VERSION="${SPARK_VERSION_ENV}" else readonly DEFAULT_CUDA_VERSION="10.1" readonly DEFAULT_CUDF_VERSION="0.9.2" @@ -66,8 +64,8 @@ function execute_with_retries() { function install_dask_rapids() { # Install RAPIDS, cudatoolkit - mamba install -y --no-channel-priority -c 'conda-forge' -c 'nvidia' -c 'rapidsai' \ - "cudatoolkit=${CUDA_VERSION}" "rapids=${RAPIDS_VERSION}" + mamba install -n 'dask-rapids' -y --no-channel-priority -c 'conda-forge' -c 'nvidia' -c 'rapidsai' \ + "cudatoolkit=${CUDA_VERSION}" "pandas<1.5" "rapids=${RAPIDS_VERSION}" "python=3.9" } function install_spark_rapids() {