diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml
index 32529c86..e30528ce 100644
--- a/.github/workflows/pypi.yml
+++ b/.github/workflows/pypi.yml
@@ -34,10 +34,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
     - uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master
-    - name: Set up Python 3.7
+    - name: Set up Python 3.9
       uses: actions/setup-python@0f07f7f756721ebd886c2462646a35f78a8bc4de # v1.2.4
       with:
-        python-version: 3.7
+        python-version: 3.9
     - name: Set up JDK 1.8
       uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4
       with:
diff --git a/.github/workflows/ray_nightly_test.yml b/.github/workflows/ray_nightly_test.yml
index 0284da22..e49c8c69 100644
--- a/.github/workflows/ray_nightly_test.yml
+++ b/.github/workflows/ray_nightly_test.yml
@@ -31,8 +31,8 @@ jobs:
     strategy:
       matrix:
         os: [ ubuntu-latest ]
-        python-version: [3.8, 3.9]
-        spark-version: [3.1.3, 3.2.4, 3.3.2, 3.4.0]
+        python-version: [3.8, 3.9, 3.10.14]
+        spark-version: [3.2.4, 3.3.2, 3.4.0, 3.5.0]

     runs-on: ${{ matrix.os }}
diff --git a/.github/workflows/raydp.yml b/.github/workflows/raydp.yml
index f3be4121..dabdb07c 100644
--- a/.github/workflows/raydp.yml
+++ b/.github/workflows/raydp.yml
@@ -33,8 +33,8 @@ jobs:
     strategy:
       matrix:
         os: [ ubuntu-latest ]
-        python-version: [3.8, 3.9]
-        spark-version: [3.1.3, 3.2.4, 3.3.2, 3.4.0]
+        python-version: [3.8, 3.9, 3.10.14]
+        spark-version: [3.2.4, 3.3.2, 3.4.0, 3.5.0]

     runs-on: ${{ matrix.os }}
@@ -82,7 +82,7 @@ jobs:
         else
           pip install torch
         fi
-        pip install pyarrow==6.0.1 ray[train] pytest koalas tensorflow==2.13.1 tabulate grpcio-tools wget
+        pip install pyarrow==6.0.1 ray[train] pytest tensorflow==2.13.1 tabulate grpcio-tools wget
         pip install "xgboost_ray[default]<=0.1.13"
         pip install torchmetrics
     - name: Cache Maven
diff --git a/.github/workflows/raydp_nightly.yml b/.github/workflows/raydp_nightly.yml
index 44304a20..ac873ed4 100644
--- a/.github/workflows/raydp_nightly.yml
+++ b/.github/workflows/raydp_nightly.yml
@@ -34,10 +34,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
     - uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master
-    - name: Set up Python 3.7
+    - name: Set up Python 3.9
       uses: actions/setup-python@0f07f7f756721ebd886c2462646a35f78a8bc4de # v1.2.4
       with:
-        python-version: 3.7
+        python-version: 3.9
     - name: Set up JDK 1.8
       uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4
       with:
diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala
index 0afcb204..949618e5 100644
--- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala
+++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala
@@ -17,21 +17,19 @@
 package org.apache.spark.sql.raydp
-
+import com.intel.raydp.shims.SparkShimLoader
+import io.ray.api.{ActorHandle, ObjectRef, PyActorHandle, Ray}
+import io.ray.runtime.AbstractRayRuntime
 import java.io.ByteArrayOutputStream
 import java.util.{List, UUID}
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
 import java.util.function.{Function => JFunction}
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-
-import io.ray.api.{ActorHandle, ObjectRef, PyActorHandle, Ray}
-import io.ray.runtime.AbstractRayRuntime
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 import org.apache.arrow.vector.types.pojo.Schema
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.{RayDPException, SparkContext}
 import org.apache.spark.deploy.raydp._
@@ -105,7 +103,7 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
       Iterator(iter)
     }
-    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val arrowSchema = SparkShimLoader.getSparkShims.toArrowSchema(schema, timeZoneId)
     val allocator = ArrowUtils.rootAllocator.newChildAllocator(
       s"ray object store writer", 0, Long.MaxValue)
     val root = VectorSchemaRoot.create(arrowSchema, allocator)
@@ -217,7 +215,7 @@ object ObjectStoreWriter {
   def toArrowSchema(df: DataFrame): Schema = {
     val conf = df.queryExecution.sparkSession.sessionState.conf
     val timeZoneId = conf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
-    ArrowUtils.toArrowSchema(df.schema, timeZoneId)
+    SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId)
   }

   def fromSparkRDD(df: DataFrame, storageLevel: StorageLevel): Array[Array[Byte]] = {
diff --git a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala
index 01d42209..2ca83522 100644
--- a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala
@@ -17,9 +17,11 @@
 package com.intel.raydp.shims

+import org.apache.arrow.vector.types.pojo.Schema
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.executor.RayDPExecutorBackendFactory
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.{DataFrame, SparkSession}

 sealed abstract class ShimDescriptor
@@ -36,4 +38,6 @@ trait SparkShims {
   def getExecutorBackendFactory(): RayDPExecutorBackendFactory

   def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext
+
+  def toArrowSchema(schema : StructType, timeZoneId : String) : Schema
 }
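
Note on the two files above: every Arrow schema conversion now goes through the SparkShims trait instead of calling Spark's ArrowUtils directly, which matters because Spark 3.5 extends ArrowUtils.toArrowSchema with an errorOnDuplicatedFieldNames parameter (visible in the spark350 shim further down). A minimal caller-side sketch of the resulting dispatch, using only SparkShimLoader and the new trait method from this patch; the object and helper names below are illustrative, not part of the patch:

    import com.intel.raydp.shims.SparkShimLoader
    import org.apache.arrow.vector.types.pojo.Schema
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.internal.SQLConf

    object ArrowSchemaSketch {
      // Convert a DataFrame's schema to an Arrow schema without touching ArrowUtils,
      // so the same call site compiles against Spark 3.2.x through 3.5.x.
      def dataFrameToArrowSchema(df: DataFrame): Schema = {
        val timeZoneId = df.queryExecution.sparkSession.sessionState.conf
          .getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
        // Resolves to Spark322Shims/Spark330Shims/Spark340Shims/Spark350Shims at runtime;
        // each shim hides its Spark version's ArrowUtils.toArrowSchema signature.
        SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId)
      }
    }
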
diff --git a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala
index 3150599f..6ea817db 100644
--- a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ b/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala
@@ -24,8 +24,9 @@ import org.apache.spark.executor.spark322._
 import org.apache.spark.spark322.TaskContextUtils
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.spark322.SparkSqlUtils
-
 import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
+import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.sql.types.StructType

 class Spark322Shims extends SparkShims {
   override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
@@ -44,4 +45,8 @@ class Spark322Shims extends SparkShims {
   override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
     TaskContextUtils.getDummyTaskContext(partitionId, env)
   }
+
+  override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
+    SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
+  }
 }
diff --git a/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
index ad0dbe7c..be9b409c 100644
--- a/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
+++ b/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
@@ -17,12 +17,19 @@
 package org.apache.spark.sql.spark322

+import org.apache.arrow.vector.types.pojo.Schema
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils

 object SparkSqlUtils {
   def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = {
     ArrowConverters.toDataFrame(rdd, schema, new SQLContext(session))
   }
+
+  def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
+    ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
+  }
 }
diff --git a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala
index fdc1af95..4f1a50b5 100644
--- a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ b/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala
@@ -24,8 +24,9 @@ import org.apache.spark.executor.spark330._
 import org.apache.spark.spark330.TaskContextUtils
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.spark330.SparkSqlUtils
-
 import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
+import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.sql.types.StructType

 class Spark330Shims extends SparkShims {
   override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
@@ -44,4 +45,8 @@ class Spark330Shims extends SparkShims {
   override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
     TaskContextUtils.getDummyTaskContext(partitionId, env)
   }
+
+  override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
+    SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
+  }
 }
diff --git a/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
index 5f3fb148..162371ad 100644
--- a/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
+++ b/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
@@ -17,12 +17,19 @@
 package org.apache.spark.sql.spark330

+import org.apache.arrow.vector.types.pojo.Schema
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils

 object SparkSqlUtils {
   def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = {
     ArrowConverters.toDataFrame(rdd, schema, session)
   }
+
+  def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
+    ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
+  }
 }
diff --git a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
index 28e0c7ed..a878c6a3 100644
--- a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
+++ b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
@@ -23,7 +23,9 @@ object SparkShimProvider {
   val SPARK340_DESCRIPTOR = SparkShimDescriptor(3, 4, 0)
   val SPARK341_DESCRIPTOR = SparkShimDescriptor(3, 4, 1)
   val SPARK342_DESCRIPTOR = SparkShimDescriptor(3, 4, 2)
-  val DESCRIPTOR_STRINGS = Seq(s"$SPARK340_DESCRIPTOR", s"$SPARK341_DESCRIPTOR", s"$SPARK342_DESCRIPTOR")
+  val SPARK343_DESCRIPTOR = SparkShimDescriptor(3, 4, 3)
+  val DESCRIPTOR_STRINGS = Seq(s"$SPARK340_DESCRIPTOR", s"$SPARK341_DESCRIPTOR", s"$SPARK342_DESCRIPTOR",
+    s"$SPARK343_DESCRIPTOR")
   val DESCRIPTOR = SPARK341_DESCRIPTOR
 }
diff --git a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala
index 50ce9aba..c444373f 100644
--- a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala
@@ -24,8 +24,9 @@ import org.apache.spark.executor.spark340._
 import org.apache.spark.spark340.TaskContextUtils
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.spark340.SparkSqlUtils
-
 import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
+import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.sql.types.StructType

 class Spark340Shims extends SparkShims {
   override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
@@ -44,4 +45,8 @@ class Spark340Shims extends SparkShims {
   override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
     TaskContextUtils.getDummyTaskContext(partitionId, env)
   }
+
+  override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
+    SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
+  }
 }
diff --git a/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
index 9159ed3a..eb52d8e7 100644
--- a/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
+++ b/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
@@ -17,11 +17,13 @@
 package org.apache.spark.sql.spark340

+import org.apache.arrow.vector.types.pojo.Schema
 import org.apache.spark.TaskContext
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils

 object SparkSqlUtils {
   def toDataFrame(
@@ -36,4 +38,8 @@ object SparkSqlUtils {
     }
     session.internalCreateDataFrame(rdd.setName("arrow"), schema)
   }
+
+  def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
+    ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
+  }
 }
diff --git a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
index bca05fa1..0a2ba58a 100644
--- a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
+++ b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
@@ -21,7 +21,8 @@ import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
 object SparkShimProvider {
   val SPARK350_DESCRIPTOR = SparkShimDescriptor(3, 5, 0)
-  val DESCRIPTOR_STRINGS = Seq(s"$SPARK350_DESCRIPTOR")
+  val SPARK351_DESCRIPTOR = SparkShimDescriptor(3, 5, 1)
+  val DESCRIPTOR_STRINGS = Seq(s"$SPARK350_DESCRIPTOR", s"$SPARK351_DESCRIPTOR")
   val DESCRIPTOR = SPARK350_DESCRIPTOR
 }
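
Note on the SparkShimProvider hunks above: adding SPARK343_DESCRIPTOR and SPARK351_DESCRIPTOR to DESCRIPTOR_STRINGS is what lets the existing spark340/spark350 shims claim the new patch releases. A rough sketch of how that list appears to be used, assuming SparkShimDescriptor(major, minor, patch) renders as "major.minor.patch" and that the provider is matched against the running Spark version string; neither detail is shown in this patch, so treat the snippet as illustration only:

    import com.intel.raydp.shims.SparkShimDescriptor

    object DescriptorMatchSketch {
      def main(args: Array[String]): Unit = {
        // Mirrors the spark340 provider after this patch: 3.4.0 through 3.4.3 are accepted.
        val spark343 = SparkShimDescriptor(3, 4, 3)
        val descriptorStrings = Seq("3.4.0", "3.4.1", "3.4.2", s"$spark343")
        // Presumably the loader compares the live Spark version against this list;
        // before this patch "3.4.3" was missing, so a Spark 3.4.3 cluster had no matching entry.
        val runtimeVersion = org.apache.spark.SPARK_VERSION
        println(descriptorStrings.contains(runtimeVersion))
      }
    }
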
diff --git a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala
index 558614a0..721d6923 100644
--- a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala
@@ -24,8 +24,9 @@ import org.apache.spark.executor.spark350._
 import org.apache.spark.spark350.TaskContextUtils
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.spark350.SparkSqlUtils
-
 import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
+import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.sql.types.StructType

 class Spark350Shims extends SparkShims {
   override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
@@ -44,4 +45,8 @@ class Spark350Shims extends SparkShims {
   override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
     TaskContextUtils.getDummyTaskContext(partitionId, env)
   }
+
+  override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
+    SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
+  }
 }
diff --git a/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
index 3ae4ae80..dfd063f7 100644
--- a/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
+++ b/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
@@ -17,11 +17,13 @@
 package org.apache.spark.sql.spark350

+import org.apache.arrow.vector.types.pojo.Schema
 import org.apache.spark.TaskContext
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils

 object SparkSqlUtils {
   def toDataFrame(
@@ -36,4 +38,8 @@ object SparkSqlUtils {
     }
     session.internalCreateDataFrame(rdd.setName("arrow"), schema)
   }
+
+  def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
+    ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, errorOnDuplicatedFieldNames = false)
+  }
 }
diff --git a/python/raydp/spark/interfaces.py b/python/raydp/spark/interfaces.py
index c8516b74..b88ec62e 100644
--- a/python/raydp/spark/interfaces.py
+++ b/python/raydp/spark/interfaces.py
@@ -20,8 +20,8 @@ from raydp.utils import convert_to_spark

-DF = Union["pyspark.sql.DataFrame", "koalas.DataFrame"]
-OPTIONAL_DF = Union[Optional["pyspark.sql.DataFrame"], Optional["koalas.DataFrame"]]
+DF = Union["pyspark.sql.DataFrame", "pyspark.pandas.DataFrame"]
+OPTIONAL_DF = Union[Optional["pyspark.sql.DataFrame"], Optional["pyspark.pandas.DataFrame"]]


 class SparkEstimatorInterface:
diff --git a/python/raydp/tests/test_spark_utils.py b/python/raydp/tests/test_spark_utils.py
index 3033ace7..86084d6c 100644
--- a/python/raydp/tests/test_spark_utils.py
+++ b/python/raydp/tests/test_spark_utils.py
@@ -18,7 +18,9 @@ import math
 import sys

-import databricks.koalas as ks
+# https://spark.apache.org/docs/latest/api/python/migration_guide/koalas_to_pyspark.html
+# import databricks.koalas as ks
+import pyspark.pandas as ps
 import pyspark
 import pytest

@@ -27,13 +29,13 @@ def test_df_type_check(spark_session):
     spark_df = spark_session.range(0, 10)
-    koalas_df = ks.range(0, 10)
+    koalas_df = ps.range(0, 10)

     assert utils.df_type_check(spark_df)
     assert utils.df_type_check(koalas_df)

     other_df = "df"
     error_msg = (f"The type: {type(other_df)} is not supported, only support " +
-                 "pyspark.sql.DataFrame and databricks.koalas.DataFrame")
+                 "pyspark.sql.DataFrame and pyspark.pandas.DataFrame")
     with pytest.raises(Exception) as exinfo:
         utils.df_type_check(other_df)
     assert str(exinfo.value) == error_msg
@@ -45,15 +47,15 @@ def test_convert_to_spark(spark_session):
     assert is_spark_df
     assert spark_df is converted

-    koalas_df = ks.range(0, 10)
-    converted, is_spark_df = utils.convert_to_spark(koalas_df)
+    pandas_on_spark_df = ps.range(0, 10)
+    converted, is_spark_df = utils.convert_to_spark(pandas_on_spark_df)
     assert not is_spark_df
     assert isinstance(converted, pyspark.sql.DataFrame)
     assert converted.count() == 10

     other_df = "df"
     error_msg = (f"The type: {type(other_df)} is not supported, only support " +
-                 "pyspark.sql.DataFrame and databricks.koalas.DataFrame")
+                 "pyspark.sql.DataFrame and pyspark.pandas.DataFrame")
     with pytest.raises(Exception) as exinfo:
         utils.df_type_check(other_df)
     assert str(exinfo.value) == error_msg
@@ -64,10 +66,10 @@ def test_random_split(spark_session):
     splits = utils.random_split(spark_df, [0.7, 0.3])
     assert len(splits) == 2

-    koalas_df = ks.range(0, 10)
+    koalas_df = ps.range(0, 10)
     splits = utils.random_split(koalas_df, [0.7, 0.3])
-    assert isinstance(splits[0], ks.DataFrame)
-    assert isinstance(splits[1], ks.DataFrame)
+    assert isinstance(splits[0], ps.DataFrame)
+    assert isinstance(splits[1], ps.DataFrame)
     assert len(splits) == 2
diff --git a/python/raydp/tests/test_torch.py b/python/raydp/tests/test_torch.py
index 183d0594..73fe6238 100644
--- a/python/raydp/tests/test_torch.py
+++ b/python/raydp/tests/test_torch.py
@@ -21,7 +21,9 @@ import shutil

 import torch
-import databricks.koalas as ks
+# https://spark.apache.org/docs/latest/api/python/migration_guide/koalas_to_pyspark.html
+# import databricks.koalas as ks
+import pyspark.pandas as ps

 from raydp.torch import TorchEstimator
 from raydp.utils import random_split
@@ -32,7 +34,7 @@ def test_torch_estimator(spark_on_ray_small, use_fs_directory):
     spark = spark_on_ray_small

     # calculate z = 3 * x + 4 * y + 5
-    df: ks.DataFrame = ks.range(0, 100000)
+    df: ps.DataFrame = ps.range(0, 100000)
     df["x"] = df["id"] + 100
     df["y"] = df["id"] + 1000
     df["z"] = df["x"] * 3 + df["y"] * 4 + 5
diff --git a/python/raydp/utils.py b/python/raydp/utils.py
index 328c4c8b..c71cd2e3 100644
--- a/python/raydp/utils.py
+++ b/python/raydp/utils.py
@@ -78,12 +78,12 @@ def random_split(df, weights, seed=None):
     if is_spark_df:
         return splits
     else:
-        # convert back to koalas DataFrame
-        import databricks.koalas as ks  # pylint: disable=C0415
-        return [ks.DataFrame(split) for split in splits]
+        # convert back to pandas on Spark DataFrame
+        import pyspark.pandas as ps  # pylint: disable=C0415
+        return [ps.DataFrame(split) for split in splits]


-def _df_helper(df, spark_callback, koalas_callback):
+def _df_helper(df, spark_callback, spark_pandas_callback):
     try:
         import pyspark  # pylint: disable=C0415
     except Exception:
@@ -93,15 +93,15 @@
         return spark_callback(df)

     try:
-        import databricks.koalas as ks  # pylint: disable=C0415
+        import pyspark.pandas as ps  # pylint: disable=C0415
     except Exception:
         pass
     else:
-        if isinstance(df, ks.DataFrame):
-            return koalas_callback(df)
+        if isinstance(df, ps.DataFrame):
+            return spark_pandas_callback(df)

     raise Exception(f"The type: {type(df)} is not supported, only support "
-                    "pyspark.sql.DataFrame and databricks.koalas.DataFrame")
+                    "pyspark.sql.DataFrame and pyspark.pandas.DataFrame")


 def df_type_check(df):
diff --git a/python/setup.py b/python/setup.py
index b5231785..1a0a4f3c 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -95,12 +95,12 @@ def run(self):
         copy2(SCRIPT_PATH, SCRIPT_TARGET)

 install_requires = [
-    "numpy",
+    "numpy < 2.0.0",
     "pandas >= 1.1.4",
     "psutil",
-    "pyarrow >= 4.0.1",
+    "pyarrow >= 4.0.1, <15.0.0",
     "ray >= 2.1.0",
-    "pyspark >= 3.1.1, <= 3.5.0",
+    "pyspark >= 3.1.1, <=3.5.1",
     "netifaces",
     "protobuf > 3.19.5, <= 3.20.3"
 ]
@@ -136,9 +136,10 @@ def run(self):
         python_requires='>=3.6',
         classifiers=[
             'License :: OSI Approved :: Apache Software License',
-            'Programming Language :: Python :: 3.6',
-            'Programming Language :: Python :: 3.7',
-            'Programming Language :: Python :: 3.8']
+            'Programming Language :: Python :: 3.8',
+            'Programming Language :: Python :: 3.9',
+            'Programming Language :: Python :: 3.10',
+            ]
     )
 finally:
     rmtree(os.path.join(TEMP_PATH, "jars"))