Skip to content

Commit

Permalink
Test Spark 4.0
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Jul 29, 2024
1 parent 8d81a0e commit d0c9090
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/data/core/matrix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ latest: &latest
matrix:
small: [*max]
full: [*min, *max]
nightly: [*min, *max, *latest]
nightly: [*min, *latest]
6 changes: 4 additions & 2 deletions onetl/_util/scala.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ def get_default_scala_version(spark_version: Version) -> Version:
"""
Get default Scala version for specific Spark version
"""
if spark_version.major < 3:
if spark_version.major == 2:
return Version("2.11")
return Version("2.12")
if spark_version.major == 3:
return Version("2.12")
return Version("2.13")
31 changes: 22 additions & 9 deletions onetl/connection/db_connection/jdbc_mixin/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,10 +431,11 @@ def _execute_on_driver(
statement_args = self._get_statement_args()
jdbc_statement = self._build_statement(statement, statement_type, jdbc_connection, statement_args)

return self._execute_statement(jdbc_statement, statement, options, callback, read_only)
return self._execute_statement(jdbc_connection, jdbc_statement, statement, options, callback, read_only)

def _execute_statement(
self,
jdbc_connection,
jdbc_statement,
statement: str,
options: JDBCFetchOptions | JDBCExecuteOptions,
Expand Down Expand Up @@ -472,7 +473,7 @@ def _execute_statement(
else:
jdbc_statement.executeUpdate(statement)

return callback(jdbc_statement)
return callback(jdbc_connection, jdbc_statement)

@staticmethod
def _build_statement(
Expand Down Expand Up @@ -501,11 +502,11 @@ def _build_statement(

return jdbc_connection.createStatement(*statement_args)

def _statement_to_dataframe(self, jdbc_statement) -> DataFrame:
def _statement_to_dataframe(self, jdbc_connection, jdbc_statement) -> DataFrame:
result_set = jdbc_statement.getResultSet()
return self._resultset_to_dataframe(result_set)
return self._resultset_to_dataframe(jdbc_connection, result_set)

def _statement_to_optional_dataframe(self, jdbc_statement) -> DataFrame | None:
def _statement_to_optional_dataframe(self, jdbc_connection, jdbc_statement) -> DataFrame | None:
"""
Returns ``org.apache.spark.sql.DataFrame`` or ``None``, if ResultSet does not contain any columns.
Expand All @@ -522,9 +523,9 @@ def _statement_to_optional_dataframe(self, jdbc_statement) -> DataFrame | None:
if not result_column_count:
return None

return self._resultset_to_dataframe(result_set)
return self._resultset_to_dataframe(jdbc_connection, result_set)

def _resultset_to_dataframe(self, result_set) -> DataFrame:
def _resultset_to_dataframe(self, jdbc_connection, result_set) -> DataFrame:
"""
Converts ``java.sql.ResultSet`` to ``org.apache.spark.sql.DataFrame`` using Spark's internal methods.
Expand All @@ -545,13 +546,25 @@ def _resultset_to_dataframe(self, result_set) -> DataFrame:

java_converters = self.spark._jvm.scala.collection.JavaConverters # type: ignore

if get_spark_version(self.spark) >= Version("3.4"):
if get_spark_version(self.spark) >= Version("4.0"):
result_schema = jdbc_utils.getSchema(
jdbc_connection,
result_set,
jdbc_dialect,
False, # noqa: WPS425
False, # noqa: WPS425
)
elif get_spark_version(self.spark) >= Version("3.4"):
# https://github.com/apache/spark/commit/2349175e1b81b0a61e1ed90c2d051c01cf78de9b
result_schema = jdbc_utils.getSchema(result_set, jdbc_dialect, False, False) # noqa: WPS425
else:
result_schema = jdbc_utils.getSchema(result_set, jdbc_dialect, False) # noqa: WPS425

result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema)
if get_spark_version(self.spark) >= Version("4.0"):
result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema, jdbc_dialect)
else:
result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema)

result_list = java_converters.seqAsJavaListConverter(result_iterator.toSeq()).asJava()
jdf = self.spark._jsparkSession.createDataFrame(result_list, result_schema) # type: ignore

Expand Down
7 changes: 6 additions & 1 deletion onetl/connection/db_connection/kafka/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,13 @@ def get_packages(
raise ValueError(f"Spark version must be at least 2.4, got {spark_ver}")

scala_ver = Version(scala_version).min_digits(2) if scala_version else get_default_scala_version(spark_ver)

if spark_ver.major < 4:
version = spark_ver.format("{0}.{1}.{2}")
else:
version = "4.0.0-preview1"
return [
f"org.apache.spark:spark-sql-kafka-0-10_{scala_ver.format('{0}.{1}')}:{spark_ver.format('{0}.{1}.{2}')}",
f"org.apache.spark:spark-sql-kafka-0-10_{scala_ver.format('{0}.{1}')}:{version}",
]

def __enter__(self):
Expand Down
7 changes: 6 additions & 1 deletion onetl/connection/file_df_connection/spark_s3/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,14 @@ def get_packages(
# https://issues.apache.org/jira/browse/SPARK-23977
raise ValueError(f"Spark version must be at least 3.x, got {spark_ver}")

if spark_ver.major < 4:
version = spark_ver.format("{0}.{1}.{2}")
else:
version = "4.0.0-preview1"

scala_ver = Version(scala_version).min_digits(2) if scala_version else get_default_scala_version(spark_ver)
# https://mvnrepository.com/artifact/org.apache.spark/spark-hadoop-cloud
return [f"org.apache.spark:spark-hadoop-cloud_{scala_ver.format('{0}.{1}')}:{spark_ver.format('{0}.{1}.{2}')}"]
return [f"org.apache.spark:spark-hadoop-cloud_{scala_ver.format('{0}.{1}')}:{version}"]

@slot
def path_from_string(self, path: os.PathLike | str) -> RemotePath:
Expand Down
7 changes: 6 additions & 1 deletion onetl/file/format/avro.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,12 @@ def get_packages(
if scala_ver < Version("2.11"):
raise ValueError(f"Scala version should be at least 2.11, got {scala_ver.format('{0}.{1}')}")

return [f"org.apache.spark:spark-avro_{scala_ver.format('{0}.{1}')}:{spark_ver.format('{0}.{1}.{2}')}"]
if spark_ver.major < 4:
version = spark_ver.format("{0}.{1}.{2}")
else:
version = "4.0.0-preview1"

return [f"org.apache.spark:spark-avro_{scala_ver.format('{0}.{1}')}:{version}"]

@slot
def check_if_supported(self, spark: SparkSession) -> None:
Expand Down
18 changes: 15 additions & 3 deletions onetl/file/format/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ def get_packages( # noqa: WPS231
)
"""
spark_ver = Version(spark_version)
if spark_ver.major >= 4:
return []

if package_version:
version = Version(package_version).min_digits(3)
Expand All @@ -202,7 +205,6 @@ def get_packages( # noqa: WPS231
else:
version = Version("0.18.0").min_digits(3)

spark_ver = Version(spark_version)
scala_ver = Version(scala_version).min_digits(2) if scala_version else get_default_scala_version(spark_ver)

# Ensure compatibility with Spark and Scala versions
Expand All @@ -216,8 +218,11 @@ def get_packages( # noqa: WPS231

@slot
def check_if_supported(self, spark: SparkSession) -> None:
java_class = "com.databricks.spark.xml.XmlReader"
version = get_spark_version(spark)
if version.major >= 4:
return

java_class = "com.databricks.spark.xml.XmlReader"
try:
try_import_java_class(spark, java_class)
except Exception as e:
Expand Down Expand Up @@ -332,19 +337,26 @@ def parse_column(self, column: str | Column, schema: StructType) -> Column:
| |-- name: string (nullable = true)
| |-- age: integer (nullable = true)
"""
from pyspark import __version__ as spark_version
from pyspark.sql import Column, SparkSession # noqa: WPS442

spark = SparkSession._instantiatedSession # noqa: WPS437
self.check_if_supported(spark)

from pyspark.sql.column import _to_java_column # noqa: WPS450
from pyspark.sql.functions import col

if isinstance(column, Column):
column_name, column = column._jc.toString(), column.cast("string") # noqa: WPS437
else:
column_name, column = column, col(column).cast("string")

if spark_version > "4":
from pyspark.sql.functions import from_xml # noqa: WPS450

return from_xml(column, schema, self.dict()).alias(column_name)

from pyspark.sql.column import _to_java_column # noqa: WPS450

java_column = _to_java_column(column)
java_schema = spark._jsparkSession.parseDataType(schema.json()) # noqa: WPS437
scala_options = spark._jvm.org.apache.spark.api.python.PythonUtils.toScalaMap( # noqa: WPS219, WPS437
Expand Down
2 changes: 1 addition & 1 deletion requirements/tests/spark-latest.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy>=1.16
pandas>=1.0
pyarrow>=1.0
pyspark
pyspark==4.0.0.dev1
sqlalchemy
2 changes: 1 addition & 1 deletion tests/fixtures/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def maven_packages(request):
# There is no MongoDB connector for Spark less than 3.2
packages.extend(MongoDB.get_packages(spark_version=str(pyspark_version)))

if "excel" in markers:
if "excel" in markers and pyspark_version < Version("4.0"):
# There is no Excel file support for Spark less than 3.2
packages.extend(Excel.get_packages(spark_version=str(pyspark_version)))

Expand Down

0 comments on commit d0c9090

Please sign in to comment.