diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c19a130..5aae361 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,15 +36,31 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v4 + with: + fetch-depth: 0 + ref: ${{ github.head_ref }} + - name: Fetch main branch + run: git fetch origin main:main - name: Check changes id: check run: | - echo "python_changed=$(git diff --name-only ${{ github.event.before }} ${{ github.event.after }} | grep '\.py$')" >> "$GITHUB_OUTPUT" - echo "toml_changed=$(git diff --name-only ${{ github.event.before }} ${{ github.event.after }} | grep '\.toml$')" >> "$GITHUB_OUTPUT" + # Set the base reference for the git diff + BASE_REF=${{ github.event.pull_request.base.ref || 'main' }} + + # Check for changes in this PR / commit + git_diff_output=$(git diff --name-only $BASE_REF ${{ github.event.after }}) + + # Count the number of changes to Python and TOML files + python_changed=$(echo "$git_diff_output" | grep '\.py$' | wc -l) + toml_changed=$(echo "$git_diff_output" | grep '\.toml$' | wc -l) + + # Write the changes to the GITHUB_OUTPUT environment file + echo "python_changed=$python_changed" >> $GITHUB_OUTPUT + echo "toml_changed=$toml_changed" >> $GITHUB_OUTPUT tests: needs: check_changes - if: needs.check_changes.outputs.python_changed != '' || needs.check_changes.outputs.toml_changed != '' || github.event_name == 'workflow_dispatch' + if: needs.check_changes.outputs.python_changed > 0 || needs.check_changes.outputs.toml_changed > 0 || github.event_name == 'workflow_dispatch' name: Python ${{ matrix.python-version }} with PySpark ${{ matrix.pyspark-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} runs-on: ${{ matrix.os }} diff --git a/pyproject.toml b/pyproject.toml index 6abf886..63f988f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ se = ["spark-expectations>=2.1.0"] # SFTP dependencies in to_csv line_iterator sftp = ["paramiko>=2.6.0"] delta = ["delta-spark>=2.2"] +excel = ["openpyxl>=3.0.0"] dev = ["black", "isort", "ruff", "mypy", "pylint", "colorama", "types-PyYAML"] test = [ "chispa", @@ -175,6 +176,7 @@ features = [ "pyspark", "sftp", "delta", + "excel", "se", "box", "dev", @@ -184,7 +186,6 @@ features = [ [tool.hatch.envs.default.scripts] # TODO: add scripts section based on Makefile # TODO: add bandit -# TODO: move scripts from linting and style here # Code Quality commands black-check = "black --check --diff ." black-fmt = "black ." @@ -192,8 +193,8 @@ isort-check = "isort . --check --diff --color" isort-fmt = "isort ." ruff-check = "ruff check ." ruff-fmt = "ruff check . 
--fix" -mypy-check = "mypy koheesio" -pylint-check = "pylint --output-format=colorized -d W0511 koheesio" +mypy-check = "mypy src" +pylint-check = "pylint --output-format=colorized -d W0511 src" check = [ "- black-check", "- isort-check", @@ -213,15 +214,7 @@ non-spark-tests = "test -m \"not spark\"" # scripts.run = "- log-versions && pytest tests/ {env:HATCH_TEST_ARGS:} {args}" # run ="echo {args}" # run = "- pytest tests/ {env:HATCH_TEST_ARGS:} {args}" -# run-cov = "coverage run -m pytest{env:HATCH_TEST_ARGS:} {args}" -# cov-combine = "coverage combine" -# cov-report = "coverage report" # log-versions = "python --version && {env:HATCH_UV} pip freeze | grep pyspark" -# -# -# -# coverage = "- pytest tests/ {env:HATCH_TEST_ARGS:} {args} --cov=koheesio --cov-report=html --cov-report=term-missing --cov-fail-under=90" -# cov = "coverage" ### ~~~~~~~~~~~~~~~~~~~~~ ### @@ -249,6 +242,7 @@ features = [ "pyspark", "sftp", "delta", + "excel", "dev", "test", ] @@ -284,6 +278,9 @@ matrix.version.extra-dependencies = [ { value = "spark-expectations>=2.1.0", if = [ "pyspark33", ] }, + { value = "pandas<2", if = [ + "pyspark33", + ] }, { value = "pyspark>=3.4,<3.5", if = [ "pyspark34", ] }, @@ -400,6 +397,7 @@ features = [ "se", "sftp", "delta", + "excel", "dev", "test", "docs", diff --git a/src/koheesio/integrations/spark/dq/spark_expectations.py b/src/koheesio/integrations/spark/dq/spark_expectations.py index 325ccaf..8766a8e 100644 --- a/src/koheesio/integrations/spark/dq/spark_expectations.py +++ b/src/koheesio/integrations/spark/dq/spark_expectations.py @@ -4,15 +4,17 @@ from typing import Any, Dict, Optional, Union -import pyspark -from pydantic import Field -from pyspark.sql import DataFrame from spark_expectations.config.user_config import Constants as user_config from spark_expectations.core.expectations import ( SparkExpectations, WrappedDataFrameWriter, ) +from pydantic import Field + +import pyspark +from pyspark.sql import DataFrame + from koheesio.spark.transformations import Transformation from koheesio.spark.writers import BatchOutputMode diff --git a/src/koheesio/models/reader.py b/src/koheesio/models/reader.py new file mode 100644 index 0000000..c122794 --- /dev/null +++ b/src/koheesio/models/reader.py @@ -0,0 +1,50 @@ +""" +Module for the BaseReader class +""" + +from typing import Optional, TypeVar +from abc import ABC, abstractmethod + +from koheesio import Step + +# Define a type variable that can be any type of DataFrame +DataFrameType = TypeVar("DataFrameType") + + +class BaseReader(Step, ABC): + """Base class for all Readers + + A Reader is a Step that reads data from a source based on the input parameters + and stores the result in self.output.df (DataFrame). + + When implementing a Reader, the execute() method should be implemented. + The execute() method should read from the source and store the result in self.output.df. + + The Reader class implements a standard read() method that calls the execute() method and returns the result. This + method can be used to read data from a Reader without having to call the execute() method directly. Read method + does not need to be implemented in the child class. + + The Reader class also implements a shorthand for accessing the output Dataframe through the df-property. If the + output.df is None, .execute() will be run first. 
+ """ + + @property + def df(self) -> Optional[DataFrameType]: + """Shorthand for accessing self.output.df + If the output.df is None, .execute() will be run first + """ + if not self.output.df: + self.execute() + return self.output.df + + @abstractmethod + def execute(self) -> Step.Output: + """Execute on a Reader should handle self.output.df (output) as a minimum + Read from whichever source -> store result in self.output.df + """ + pass + + def read(self) -> DataFrameType: + """Read from a Reader without having to call the execute() method directly""" + self.execute() + return self.output.df diff --git a/src/koheesio/pandas/__init__.py b/src/koheesio/pandas/__init__.py new file mode 100644 index 0000000..a9d324a --- /dev/null +++ b/src/koheesio/pandas/__init__.py @@ -0,0 +1,27 @@ +"""Base class for a Pandas step + +Extends the Step class with Pandas DataFrame support. The following: +- Pandas steps are expected to return a Pandas DataFrame as output. +""" + +from typing import Optional +from abc import ABC + +from koheesio import Step, StepOutput +from koheesio.models import Field +from koheesio.spark.utils import import_pandas_based_on_pyspark_version + +pandas = import_pandas_based_on_pyspark_version() + + +class PandasStep(Step, ABC): + """Base class for a Pandas step + + Extends the Step class with Pandas DataFrame support. The following: + - Pandas steps are expected to return a Pandas DataFrame as output. + """ + + class Output(StepOutput): + """Output class for PandasStep""" + + df: Optional[pandas.DataFrame] = Field(default=None, description="The Pandas DataFrame") diff --git a/src/koheesio/pandas/readers/__init__.py b/src/koheesio/pandas/readers/__init__.py new file mode 100644 index 0000000..933561a --- /dev/null +++ b/src/koheesio/pandas/readers/__init__.py @@ -0,0 +1,34 @@ +""" +Base class for all Readers +""" + +from abc import ABC, abstractmethod + +from koheesio.models.reader import BaseReader +from koheesio.pandas import PandasStep + + +class Reader(BaseReader, PandasStep, ABC): + """Base class for all Readers + + A Reader is a Step that reads data from a source based on the input parameters + and stores the result in self.output.df (DataFrame). + + When implementing a Reader, the execute() method should be implemented. + The execute() method should read from the source and store the result in self.output.df. + + The Reader class implements a standard read() method that calls the execute() method and returns the result. This + method can be used to read data from a Reader without having to call the execute() method directly. Read method + does not need to be implemented in the child class. + + The Reader class also implements a shorthand for accessing the output Dataframe through the df-property. If the + output.df is None, .execute() will be run first. + """ + + @abstractmethod + def execute(self) -> PandasStep.Output: + """Execute on a Reader should handle self.output.df (output) as a minimum + Read from whichever source -> store result in self.output.df + """ + # self.output.df # output dataframe + ... diff --git a/src/koheesio/pandas/readers/excel.py b/src/koheesio/pandas/readers/excel.py new file mode 100644 index 0000000..5432aed --- /dev/null +++ b/src/koheesio/pandas/readers/excel.py @@ -0,0 +1,50 @@ +""" +Excel reader for Spark + +Note +---- +Ensure the 'excel' extra is installed before using this reader. +Default implementation uses openpyxl as the engine for reading Excel files. 
+Other implementations can be used by passing the correct keyword arguments to the reader. + +See Also +-------- +- https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_excel.html +- koheesio.pandas.readers.excel.ExcelReader +""" + +from typing import List, Optional, Union +from pathlib import Path + +import pandas as pd + +from koheesio.models import ExtraParamsMixin, Field +from koheesio.pandas.readers import Reader + + +class ExcelReader(Reader, ExtraParamsMixin): + """Read data from an Excel file + + See Also + -------- + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_excel.html + + Attributes + ---------- + path : Union[str, Path] + The path to the Excel file + sheet_name : str + The name of the sheet to read + header : Optional[Union[int, List[int]]] + Row(s) to use as the column names + + Any other keyword arguments will be passed to pd.read_excel. + """ + + path: Union[str, Path] = Field(description="The path to the Excel file") + sheet_name: str = Field(default="Sheet1", description="The name of the sheet to read") + header: Optional[Union[int, List[int]]] = Field(default=0, description="Row(s) to use as the column names") + + def execute(self): + extra_params = self.params or {} + self.output.df = pd.read_excel(self.path, sheet_name=self.sheet_name, header=self.header, **extra_params) diff --git a/src/koheesio/spark/readers/__init__.py b/src/koheesio/spark/readers/__init__.py index 42d0870..ab81a7d 100644 --- a/src/koheesio/spark/readers/__init__.py +++ b/src/koheesio/spark/readers/__init__.py @@ -6,15 +6,13 @@ [reference/concepts/steps/readers](../../../reference/concepts/readers.md) section of the Koheesio documentation. """ -from typing import Optional from abc import ABC, abstractmethod -from pyspark.sql import DataFrame - +from koheesio.models.reader import BaseReader from koheesio.spark import SparkStep -class Reader(SparkStep, ABC): +class Reader(BaseReader, SparkStep, ABC): """Base class for all Readers A Reader is a Step that reads data from a source based on the input parameters @@ -33,24 +31,8 @@ class Reader(SparkStep, ABC): output.df is None, .execute() will be run first. """ - @property - def df(self) -> Optional[DataFrame]: - """Shorthand for accessing self.output.df - If the output.df is None, .execute() will be run first - """ - if not self.output.get("df"): - self.execute() - return self.output.df - @abstractmethod - def execute(self): + def execute(self) -> SparkStep.Output: """Execute on a Reader should handle self.output.df (output) as a minimum Read from whichever source -> store result in self.output.df """ - # self.output.df # output dataframe - ... - - def read(self) -> Optional[DataFrame]: - """Read from a Reader without having to call the execute() method directly""" - self.execute() - return self.output.df diff --git a/src/koheesio/spark/readers/excel.py b/src/koheesio/spark/readers/excel.py new file mode 100644 index 0000000..4b52cc7 --- /dev/null +++ b/src/koheesio/spark/readers/excel.py @@ -0,0 +1,40 @@ +""" +Excel reader for Spark + +Note +---- +Ensure the 'excel' extra is installed before using this reader. +Default implementation uses openpyxl as the engine for reading Excel files. +Other implementations can be used by passing the correct keyword arguments to the reader. 
+ +See Also +-------- +- https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_excel.html +- koheesio.pandas.readers.excel.ExcelReader +""" + +from pyspark.pandas import DataFrame as PandasDataFrame + +from koheesio.pandas.readers.excel import ExcelReader as PandasExcelReader +from koheesio.spark.readers import Reader + + +class ExcelReader(Reader, PandasExcelReader): + """Read data from an Excel file + + This class is a wrapper around the PandasExcelReader class. It reads an Excel file first using pandas, and then + converts the pandas DataFrame to a Spark DataFrame. + + Attributes + ---------- + path: str + The path to the Excel file + sheet_name: str + The name of the sheet to read + header: int + The row to use as the column names + """ + + def execute(self): + pdf: PandasDataFrame = PandasExcelReader.from_step(self).execute().df + self.output.df = self.spark.createDataFrame(pdf) diff --git a/src/koheesio/spark/utils.py b/src/koheesio/spark/utils.py index cffdf75..b382c4b 100644 --- a/src/koheesio/spark/utils.py +++ b/src/koheesio/spark/utils.py @@ -29,6 +29,7 @@ __all__ = [ "SparkDatatype", "get_spark_minor_version", + "import_pandas_based_on_pyspark_version", "on_databricks", "schema_struct_to_schema_str", "spark_data_type_is_array", @@ -177,3 +178,26 @@ def schema_struct_to_schema_str(schema: StructType) -> str: if not schema: return "" return ",\n".join([f"{field.name} {field.dataType.typeName().upper()}" for field in schema.fields]) + + +def import_pandas_based_on_pyspark_version(): + """ + This function checks the installed version of PySpark and then tries to import the appropriate version of pandas. + If the correct version of pandas is not installed, it raises an ImportError with a message indicating which version + of pandas should be installed. 
+ """ + try: + import pandas as pd + + pyspark_version = get_spark_minor_version() + pandas_version = pd.__version__ + + if (pyspark_version < 3.4 and pandas_version >= "2") or (pyspark_version >= 3.4 and pandas_version < "2"): + raise ImportError( + f"For PySpark {pyspark_version}, " + f"please install Pandas version {'< 2' if pyspark_version < 3.4 else '>= 2'}" + ) + + return pd + except ImportError as e: + raise ImportError("Pandas module is not installed.") from e diff --git a/src/koheesio/utils.py b/src/koheesio/utils.py index 61bd16a..9547f8c 100644 --- a/src/koheesio/utils.py +++ b/src/koheesio/utils.py @@ -3,7 +3,6 @@ """ import inspect -import os import uuid from typing import Any, Callable, Dict, Optional, Tuple from functools import partial diff --git a/tests/_data/readers/excel_file/dummy.xlsx b/tests/_data/readers/excel_file/dummy.xlsx new file mode 100644 index 0000000..eeef110 Binary files /dev/null and b/tests/_data/readers/excel_file/dummy.xlsx differ diff --git a/tests/conftest.py b/tests/conftest.py index 6a6e4b9..a0090a0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,12 @@ import os import time import uuid +from pathlib import Path import pytest from koheesio.logger import LoggingFactory +from koheesio.utils import get_project_root if os.name != "nt": # 'nt' is the name for Windows # force time zone to be UTC @@ -12,6 +14,12 @@ time.tzset() +PROJECT_ROOT = get_project_root() + +TEST_DATA_PATH = Path(PROJECT_ROOT / "tests" / "_data") +DELTA_FILE = Path(TEST_DATA_PATH / "readers" / "delta_file") + + @pytest.fixture(scope="session") def random_uuid(): return str(uuid.uuid4()).replace("-", "_") @@ -20,3 +28,13 @@ def random_uuid(): @pytest.fixture(scope="session") def logger(random_uuid): return LoggingFactory.get_logger(name="conf_test" + random_uuid) + + +@pytest.fixture(scope="session") +def data_path(): + return TEST_DATA_PATH.as_posix() + + +@pytest.fixture(scope="session") +def delta_file(): + return DELTA_FILE.as_posix() diff --git a/tests/pandas/readers/test_pandas_excel.py b/tests/pandas/readers/test_pandas_excel.py new file mode 100644 index 0000000..4e00a23 --- /dev/null +++ b/tests/pandas/readers/test_pandas_excel.py @@ -0,0 +1,28 @@ +from pathlib import Path + +import pandas as pd + +from koheesio.pandas.readers.excel import ExcelReader + + +def test_excel_reader(data_path): + # Define the path to the test Excel file and the expected DataFrame + test_file = Path(data_path) / "readers" / "excel_file" / "dummy.xlsx" + expected_df = pd.DataFrame( + { + "a": ["foo", "so long"], + "b": ["bar", "and thanks"], + "c": ["baz", "for all the fish"], + "d": [None, 42], + "e": pd.to_datetime(["1/1/24", "1/1/24"], format="%m/%d/%y"), + } + ) + + # Initialize the ExcelReader with the path to the test file and the correct sheet name + reader = ExcelReader(path=test_file, sheet_name="sheet_to_select", header=0) + + # Execute the reader + reader.execute() + + # Assert that the output DataFrame is as expected + pd.testing.assert_frame_equal(reader.output.df, expected_df, check_dtype=False) diff --git a/tests/spark/conftest.py b/tests/spark/conftest.py index fc01d6f..f6f40f9 100644 --- a/tests/spark/conftest.py +++ b/tests/spark/conftest.py @@ -32,11 +32,6 @@ from koheesio.logger import LoggingFactory from koheesio.spark.readers.dummy import DummyReader -from koheesio.utils import get_project_root - -PROJECT_ROOT = get_project_root() -TEST_DATA_PATH = Path(PROJECT_ROOT / "tests" / "_data") -DELTA_FILE = Path(TEST_DATA_PATH / "readers" / "delta_file") 
@pytest.fixture(scope="session") @@ -53,11 +48,6 @@ def checkpoint_folder(tmp_path_factory, random_uuid, logger): yield fldr.as_posix() -@pytest.fixture(scope="session") -def data_path(): - return TEST_DATA_PATH.as_posix() - - @pytest.fixture(scope="session") def spark(warehouse_path, random_uuid): """Spark session fixture with Delta enabled.""" @@ -107,14 +97,14 @@ def set_env_vars(): @pytest.fixture(scope="session", autouse=True) -def setup(spark): +def setup(spark, delta_file): db_name = "klettern" if not spark.catalog.databaseExists(db_name): spark.sql(f"CREATE DATABASE {db_name}") spark.sql(f"USE {db_name}") - setup_test_data(spark=spark) + setup_test_data(spark=spark, delta_file=Path(delta_file)) yield @@ -142,8 +132,8 @@ def sample_df_to_partition(spark): @pytest.fixture -def streaming_dummy_df(spark): - setup_test_data(spark=spark) +def streaming_dummy_df(spark, delta_file): + setup_test_data(spark=spark, delta_file=Path(delta_file)) yield spark.readStream.table("delta_test_table") @@ -198,12 +188,12 @@ def sample_df_with_string_timestamp(spark): return spark.createDataFrame(data, schema) -def setup_test_data(spark): +def setup_test_data(spark, delta_file): """ Sets up test data for the Spark session. Reads a Delta file, creates a temporary view, and populates a Delta table with the view's data. """ - delta_file = DELTA_FILE.absolute().as_posix() + delta_file = delta_file.absolute().as_posix() spark.read.format("delta").load(delta_file).limit(10).createOrReplaceTempView("delta_test_view") spark.sql( dedent( diff --git a/tests/spark/integrations/dq/test_spark_expectations.py b/tests/spark/integrations/dq/test_spark_expectations.py index a8ef6fb..259af00 100644 --- a/tests/spark/integrations/dq/test_spark_expectations.py +++ b/tests/spark/integrations/dq/test_spark_expectations.py @@ -1,10 +1,12 @@ from typing import List, Union -import pyspark import pytest -from koheesio.utils import get_project_root + +import pyspark from pyspark.sql import SparkSession +from koheesio.utils import get_project_root + PROJECT_ROOT = get_project_root() pytestmark = pytest.mark.spark diff --git a/tests/spark/readers/test_spark_excel.py b/tests/spark/readers/test_spark_excel.py new file mode 100644 index 0000000..1a6d34a --- /dev/null +++ b/tests/spark/readers/test_spark_excel.py @@ -0,0 +1,27 @@ +import datetime +from pathlib import Path + +from koheesio.spark.readers.excel import ExcelReader + + +def test_excel_reader(spark, data_path): + # Define the path to the test Excel file and the expected DataFrame + test_file = Path(data_path) / "readers" / "excel_file" / "dummy.xlsx" + + # Initialize the ExcelReader with the path to the test file and the correct sheet name + reader = ExcelReader(path=test_file, sheet_name="sheet_to_select", header=0) + + # Execute the reader + reader.execute() + + # Define the expected DataFrame + expected_df = spark.createDataFrame( + [ + ("foo", "bar", "baz", None, datetime.datetime(2024, 1, 1, 0, 0)), + ("so long", "and thanks", "for all the fish", 42, datetime.datetime(2024, 1, 1, 0, 0)), + ], + ["a", "b", "c", "d", "e"], + ) + + # Assert that the output DataFrame is as expected + assert sorted(reader.output.df.collect()) == sorted(expected_df.collect()) diff --git a/tests/spark/test_delta.py b/tests/spark/test_delta.py index 7b0b100..5bad4a7 100644 --- a/tests/spark/test_delta.py +++ b/tests/spark/test_delta.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from unittest.mock import patch import pytest @@ -74,8 +75,8 @@ def test_table(value, expected): 
log.info("delta test completed") -def test_delta_table_properties(spark, setup): - setup_test_data(spark=spark) +def test_delta_table_properties(spark, setup, delta_file): + setup_test_data(spark=spark, delta_file=Path(delta_file)) table_name = "delta_test_table" dt = DeltaTableStep( table=table_name, diff --git a/tests/spark/test_spark_utils.py b/tests/spark/test_spark_utils.py new file mode 100644 index 0000000..6455bea --- /dev/null +++ b/tests/spark/test_spark_utils.py @@ -0,0 +1,53 @@ +from os import environ +from unittest.mock import patch + +import pytest + +from pyspark.sql.types import StringType, StructField, StructType + +from koheesio.spark.utils import ( + import_pandas_based_on_pyspark_version, + on_databricks, + schema_struct_to_schema_str, +) + + +def test_schema_struct_to_schema_str(): + struct_schema = StructType([StructField("a", StringType()), StructField("b", StringType())]) + val = schema_struct_to_schema_str(struct_schema) + assert val == "a STRING,\nb STRING" + assert schema_struct_to_schema_str(None) == "" + + +@pytest.mark.parametrize( + "env_var_value, expected_result", + [("lts_11_spark_3_scala_2.12", True), ("unit_test", True), (None, False)], +) +def test_on_databricks(env_var_value, expected_result): + if env_var_value is not None: + with patch.dict(environ, {"DATABRICKS_RUNTIME_VERSION": env_var_value}): + assert on_databricks() == expected_result + else: + with patch.dict(environ, clear=True): + assert on_databricks() == expected_result + + +@pytest.mark.parametrize( + "spark_version, pandas_version, expected_error", + [ + (3.3, "1.2.3", None), # PySpark 3.3, pandas < 2, should not raise an error + (3.4, "2.3.4", None), # PySpark not 3.3, pandas >= 2, should not raise an error + (3.3, "2.3.4", ImportError), # PySpark 3.3, pandas >= 2, should raise an error + (3.4, "1.2.3", ImportError), # PySpark not 3.3, pandas < 2, should raise an error + ], +) +def test_import_pandas_based_on_pyspark_version(spark_version, pandas_version, expected_error): + with ( + patch("koheesio.spark.utils.get_spark_minor_version", return_value=spark_version), + patch("pandas.__version__", new=pandas_version), + ): + if expected_error: + with pytest.raises(expected_error): + import_pandas_based_on_pyspark_version() + else: + import_pandas_based_on_pyspark_version() # This should not raise an error diff --git a/tests/spark/transformations/test_sql_transform.py b/tests/spark/transformations/test_sql_transform.py index d9822c5..9834d9c 100644 --- a/tests/spark/transformations/test_sql_transform.py +++ b/tests/spark/transformations/test_sql_transform.py @@ -1,7 +1,7 @@ +from pathlib import Path from textwrap import dedent import pytest -from conftest import TEST_DATA_PATH from koheesio.logger import LoggingFactory from koheesio.spark.transformations.sql_transform import SqlTransform @@ -11,6 +11,11 @@ log = LoggingFactory.get_logger(name="test_sql_transform") +@pytest.fixture +def test_data_path(data_path) -> Path: + return Path(data_path) / "transformations" + + @pytest.mark.parametrize( "input_values,expected", [ @@ -48,7 +53,7 @@ # input values dict( table_name="dummy_table", - sql_path=TEST_DATA_PATH / "transformations" / "dummy.sql", + sql_path="dummy.sql", ), # expected output {"id": 0, "incremented_id": 1}, @@ -58,14 +63,16 @@ # input values dict( table_name="dummy_table", - sql_path=str((TEST_DATA_PATH / "transformations" / "dummy.sql").as_posix()), + sql_path="dummy.sql", ), # expected output {"id": 0, "incremented_id": 1}, ), ], ) -def test_sql_transform(input_values, expected, 
dummy_df): +def test_sql_transform(input_values, expected, dummy_df, test_data_path): + if sql_path := input_values.get("sql_path"): + input_values["sql_path"] = str((test_data_path / sql_path).as_posix()) result = SqlTransform(**input_values).transform(dummy_df) actual = result.head().asDict() diff --git a/tests/spark/writers/delta/test_delta_writer.py b/tests/spark/writers/delta/test_delta_writer.py index 21d0272..4a36069 100644 --- a/tests/spark/writers/delta/test_delta_writer.py +++ b/tests/spark/writers/delta/test_delta_writer.py @@ -4,14 +4,17 @@ import pytest from conftest import await_job_completion from delta import DeltaTable + +from pydantic import ValidationError + +from pyspark.sql import functions as F + from koheesio.spark import AnalysisException from koheesio.spark.delta import DeltaTableStep from koheesio.spark.writers import BatchOutputMode, StreamingOutputMode from koheesio.spark.writers.delta import DeltaTableStreamWriter, DeltaTableWriter from koheesio.spark.writers.delta.utils import log_clauses from koheesio.spark.writers.stream import Trigger -from pydantic import ValidationError -from pyspark.sql import functions as F pytestmark = pytest.mark.spark diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 9039891..ead642f 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -25,26 +25,6 @@ def test_import_class(): assert import_class("datetime.datetime") == datetime.datetime -@pytest.mark.parametrize( - "env_var_value, expected_result", - [("lts_11_spark_3_scala_2.12", True), ("unit_test", True), (None, False)], -) -def test_on_databricks(env_var_value, expected_result): - if env_var_value is not None: - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": env_var_value}): - assert on_databricks() == expected_result - else: - with patch.dict(os.environ, clear=True): - assert on_databricks() == expected_result - - -def test_schema_struct_to_schema_str(): - struct_schema = StructType([StructField("a", StringType()), StructField("b", StringType())]) - val = schema_struct_to_schema_str(struct_schema) - assert val == "a STRING,\nb STRING" - assert schema_struct_to_schema_str(None) == "" - - def test_get_random_string(): assert get_random_string(10) != get_random_string(10) assert len(get_random_string(10)) == 10
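
Usage note (not part of the diff): a minimal sketch of how the two new Excel readers introduced above are expected to be used. The path and sheet name below are placeholders, the 'excel' extra (openpyxl) is assumed to be installed, and the Spark variant additionally assumes an active SparkSession.

    from koheesio.pandas.readers.excel import ExcelReader as PandasExcelReader
    from koheesio.spark.readers.excel import ExcelReader as SparkExcelReader

    # Pandas flavour: execute() stores the result in output.df; read() is the BaseReader shorthand
    pandas_df = PandasExcelReader(path="some/folder/dummy.xlsx", sheet_name="Sheet1", header=0).read()

    # Spark flavour: reads the file via pandas first, then converts it with spark.createDataFrame
    spark_df = SparkExcelReader(path="some/folder/dummy.xlsx", sheet_name="Sheet1", header=0).read()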
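
Reviewer sketch of the contract that koheesio.models.reader.BaseReader now centralises: a concrete reader only implements execute() and stores its result in self.output.df, while read() and the .df shorthand are inherited. The CSV-based reader below is purely hypothetical and not part of this change.

    import pandas as pd

    from koheesio.models import Field
    from koheesio.pandas.readers import Reader


    class CsvReader(Reader):
        """Hypothetical pandas reader, shown only to illustrate the BaseReader contract."""

        path: str = Field(description="Path to a CSV file")

        def execute(self) -> Reader.Output:
            # Populating self.output.df is the only obligation of a concrete Reader
            self.output.df = pd.read_csv(self.path)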
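
The new import_pandas_based_on_pyspark_version helper pairs with the "pandas<2" pin added to the pyspark33 test matrix: pandas<2 is expected for PySpark below 3.4 and pandas>=2 for PySpark 3.4 and up. A minimal sketch of using it in place of a bare pandas import:

    from koheesio.spark.utils import import_pandas_based_on_pyspark_version

    # Returns the pandas module when the installed pandas matches the PySpark line,
    # otherwise raises ImportError stating which pandas version should be installed.
    pandas = import_pandas_based_on_pyspark_version()
    df = pandas.DataFrame({"a": [1, 2, 3]})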