
Commit

improved coverage
dannymeijer committed Nov 8, 2024
1 parent 8997b91 commit 90e9a77
Showing 3 changed files with 117 additions and 15 deletions.
2 changes: 1 addition & 1 deletion makefile
@@ -43,7 +43,7 @@ init: hatch-install
.PHONY: sync ## hatch - Update dependencies if you changed project dependencies in pyproject.toml
.PHONY: update ## hatch - alias for sync (if you are used to poetry, this is similar to running `poetry update`)
sync:
@hatch run dev:uv sync --all-extras --dev
@hatch run dev:uv sync --all-extras
update: sync

# Code Quality
Expand Down
32 changes: 23 additions & 9 deletions src/koheesio/spark/utils/common.py
@@ -80,10 +80,10 @@ def get_spark_minor_version() -> float:
def check_if_pyspark_connect_is_supported() -> bool:
"""Check if the current version of PySpark supports the connect module"""
result = False
module_name: str = "pyspark"

if SPARK_MINOR_VERSION >= 3.5:
try:
importlib.import_module(f"{module_name}.sql.connect")
importlib.import_module("pyspark.sql.connect")
from pyspark.sql.connect.column import Column

_col: Column
@@ -119,9 +119,13 @@ def check_if_pyspark_connect_is_supported() -> bool:
ParseException = (CapturedParseException, ConnectParseException)
DataType = Union[SqlDataType, ConnectDataType]
DataFrameReader = Union[sql.readwriter.DataFrameReader, DataFrameReader]
DataStreamReader = Union[sql.streaming.readwriter.DataStreamReader, DataStreamReader]
DataStreamReader = Union[
sql.streaming.readwriter.DataStreamReader, DataStreamReader
]
DataFrameWriter = Union[sql.readwriter.DataFrameWriter, DataFrameWriter]
DataStreamWriter = Union[sql.streaming.readwriter.DataStreamWriter, DataStreamWriter]
DataStreamWriter = Union[
sql.streaming.readwriter.DataStreamWriter, DataStreamWriter
]
StreamingQuery = StreamingQuery
else:
"""Import the regular PySpark modules if the current version of PySpark does not support the connect module"""
@@ -156,8 +160,9 @@ def check_if_pyspark_connect_is_supported() -> bool:

def get_active_session() -> SparkSession: # type: ignore
"""Get the active Spark session"""
print("Entering get_active_session")
if check_if_pyspark_connect_is_supported():
from pyspark.sql.connect.session import SparkSession as _ConnectSparkSession
from pyspark.sql.connect import SparkSession as _ConnectSparkSession

session = _ConnectSparkSession.getActiveSession() or sql.SparkSession.getActiveSession() # type: ignore
else:
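
Aside, not part of the diff: a small usage sketch of get_active_session under classic PySpark. It assumes a local PySpark installation; the function simply returns whatever session is already active and raises a RuntimeError when none exists.

# Hedged usage sketch, not from the commit.
from pyspark.sql import SparkSession

from koheesio.spark.utils.common import get_active_session

spark = SparkSession.builder.master("local[1]").appName("demo").getOrCreate()
session = get_active_session()  # finds the session created above
assert session is not None
spark.stop()  # after this, another call would raise the RuntimeError the new tests assert on
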
@@ -292,14 +297,18 @@ def spark_data_type_is_array(data_type: DataType) -> bool: # type: ignore

def spark_data_type_is_numeric(data_type: DataType) -> bool: # type: ignore
"""Check if the column's dataType is of type ArrayType"""
return isinstance(data_type, (IntegerType, LongType, FloatType, DoubleType, DecimalType))
return isinstance(
data_type, (IntegerType, LongType, FloatType, DoubleType, DecimalType)
)


def schema_struct_to_schema_str(schema: StructType) -> str:
"""Converts a StructType to a schema str"""
if not schema:
return ""
return ",\n".join([f"{field.name} {field.dataType.typeName().upper()}" for field in schema.fields])
return ",\n".join(
[f"{field.name} {field.dataType.typeName().upper()}" for field in schema.fields]
)
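
For reference (not part of the diff), the kind of output this helper produces, assuming standard PySpark types and made-up field names:

# Illustrative example.
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

from koheesio.spark.utils.common import schema_struct_to_schema_str

schema = StructType([StructField("id", IntegerType()), StructField("name", StringType())])
print(schema_struct_to_schema_str(schema))
# id INTEGER,
# name STRING
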


def import_pandas_based_on_pyspark_version() -> ModuleType:
@@ -314,7 +323,9 @@ def import_pandas_based_on_pyspark_version() -> ModuleType:
pyspark_version = get_spark_minor_version()
pandas_version = pd.__version__

if (pyspark_version < 3.4 and pandas_version >= "2") or (pyspark_version >= 3.4 and pandas_version < "2"):
if (pyspark_version < 3.4 and pandas_version >= "2") or (
pyspark_version >= 3.4 and pandas_version < "2"
):
raise ImportError(
f"For PySpark {pyspark_version}, "
f"please install Pandas version {'< 2' if pyspark_version < 3.4 else '>= 2'}"
@@ -379,7 +390,10 @@ def get_column_name(col: Column) -> str: # type: ignore
# In case of a 'regular' Column object, we can directly access the name attribute through the _jc attribute
# noinspection PyProtectedMember
name = col._jc.toString() # type: ignore[operator]
elif any(cls.__module__ == "pyspark.sql.connect.column" for cls in inspect.getmro(col.__class__)):
elif any(
cls.__module__ == "pyspark.sql.connect.column"
for cls in inspect.getmro(col.__class__)
):
# noinspection PyProtectedMember
name = col._expr.name()
else:
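
Aside, not part of the diff: a minimal sketch of get_column_name on a plain, unaliased column. It resolves the name whether the Column is the classic JVM-backed kind or the Spark Connect kind; behaviour for aliased or complex expressions may differ.

# Hedged sketch; assumes an active classic SparkSession or Spark Connect session.
from pyspark.sql.functions import col

from koheesio.spark.utils.common import get_column_name

assert get_column_name(col("amount")) == "amount"
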
98 changes: 93 additions & 5 deletions tests/spark/test_spark_utils.py
@@ -1,5 +1,5 @@
from os import environ
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest

@@ -12,10 +12,89 @@
schema_struct_to_schema_str,
show_string,
)
from koheesio.spark.utils.common import (
check_if_pyspark_connect_is_supported,
get_active_session,
get_spark_minor_version,
)


class TestGetActiveSession:
def test_unhappy_get_active_session_spark_connect(self):
"""Test that get_active_session raises an error when no active session is found when using spark connect."""
with (
# ensure that we are forcing the code to think that we are using spark connect
patch(
"koheesio.spark.utils.common.check_if_pyspark_connect_is_supported",
return_value=True,
),
# make sure that spark session is not found
patch("pyspark.sql.SparkSession.getActiveSession", return_value=None),
):
session = MagicMock(
SparkSession=MagicMock(getActiveSession=MagicMock(return_value=None))
)
with patch.dict("sys.modules", {"pyspark.sql.connect": session}):
with pytest.raises(
RuntimeError,
match="No active Spark session found. Please create a Spark session before using module "
"connect_utils. Or perform local import of the module.",
):
get_active_session()

def test_unhappy_get_active_session(self):
"""Test that get_active_session raises an error when no active session is found."""
with (
patch(
"koheesio.spark.utils.common.check_if_pyspark_connect_is_supported",
return_value=False,
),
patch("pyspark.sql.SparkSession.getActiveSession", return_value=None),
):
with pytest.raises(
RuntimeError,
match="No active Spark session found. Please create a Spark session before using module connect_utils. "
"Or perform local import of the module.",
):
get_active_session()

def test_get_active_session_with_spark(self, spark):
"""Test get_active_session when an active session is found"""
session = get_active_session()
assert session is not None


class TestCheckIfPysparkConnectIsSupported:
def test_if_pyspark_connect_is_not_supported(self):
"""Test that check_if_pyspark_connect_is_supported returns False when pyspark connect is not supported."""
with patch.dict("sys.modules", {"pyspark.sql.connect": None}):
assert check_if_pyspark_connect_is_supported() is False

def test_check_if_pyspark_connect_is_supported(self):
"""Test that check_if_pyspark_connect_is_supported returns True when pyspark connect is supported."""
with (
patch("koheesio.spark.utils.common.SPARK_MINOR_VERSION", 3.5),
patch.dict(
"sys.modules",
{
"pyspark.sql.connect.column": MagicMock(Column=MagicMock()),
"pyspark.sql.connect": MagicMock(),
},
),
):
assert check_if_pyspark_connect_is_supported() is True
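
The tests above steer Python's import machinery by patching sys.modules; a standalone sketch of that trick, using the same module name the tests target:

import importlib
from unittest.mock import patch

# Mapping a module name to None in sys.modules makes any import of it raise ImportError,
# which is how the test simulates "pyspark connect is not available".
with patch.dict("sys.modules", {"pyspark.sql.connect": None}):
    try:
        importlib.import_module("pyspark.sql.connect")
    except ImportError:
        print("import blocked, as the test expects")
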


def test_get_spark_minor_version():
"""Test that get_spark_minor_version returns the correctly formatted version."""
with patch("koheesio.spark.utils.common.spark_version", "9.9.42"):
assert get_spark_minor_version() == 9.9


def test_schema_struct_to_schema_str():
struct_schema = StructType([StructField("a", StringType()), StructField("b", StringType())])
struct_schema = StructType(
[StructField("a", StringType()), StructField("b", StringType())]
)
val = schema_struct_to_schema_str(struct_schema)
assert val == "a STRING,\nb STRING"
assert schema_struct_to_schema_str(None) == ""
Expand All @@ -40,12 +119,21 @@ def test_on_databricks(env_var_value, expected_result):
(3.3, "1.2.3", None), # PySpark 3.3, pandas < 2, should not raise an error
(3.4, "2.3.4", None), # PySpark not 3.3, pandas >= 2, should not raise an error
(3.3, "2.3.4", ImportError), # PySpark 3.3, pandas >= 2, should raise an error
(3.4, "1.2.3", ImportError), # PySpark not 3.3, pandas < 2, should raise an error
(
3.4,
"1.2.3",
ImportError,
), # PySpark not 3.3, pandas < 2, should raise an error
],
)
def test_import_pandas_based_on_pyspark_version(spark_version, pandas_version, expected_error):
def test_import_pandas_based_on_pyspark_version(
spark_version, pandas_version, expected_error
):
with (
patch("koheesio.spark.utils.common.get_spark_minor_version", return_value=spark_version),
patch(
"koheesio.spark.utils.common.get_spark_minor_version",
return_value=spark_version,
),
patch("pandas.__version__", new=pandas_version),
):
if expected_error:

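One way to run just this test module locally, as a hedged suggestion (assumes pytest is installed in the project's dev environment):

# Not part of the commit: invoke pytest programmatically on the changed test file.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["tests/spark/test_spark_utils.py", "-v"]))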
