diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index 581ad0206..cb3332edb 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -1,12 +1,16 @@ # Upcoming Release -* Removed support for Python 3.7 ## Major features and improvements +* Removed support for Python 3.7 +* Spark and Databricks based datasets now support [databricks-connect>=13.0](https://docs.databricks.com/en/dev-tools/databricks-connect-ref.html) + ## Bug fixes and other changes * Fixed bug with loading models saved with `TensorFlowModelDataset`. + ## Community contributions Many thanks to the following Kedroids for contributing PRs to this release: * [Edouard59](https://github.com/Edouard59) +* [Miguel Rodriguez Gutierrez](https://github.com/MigQ2) # Release 1.8.0 ## Major features and improvements diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 0af44ecf1..0e7d01128 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -9,12 +9,13 @@ import pandas as pd from kedro.io.core import Version, VersionNotFoundError -from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import DataFrame from pyspark.sql.types import StructType from pyspark.sql.utils import AnalysisException, ParseException from kedro_datasets import KedroDeprecationWarning from kedro_datasets._io import AbstractVersionedDataset, DatasetError +from kedro_datasets.spark.spark_dataset import _get_spark logger = logging.getLogger(__name__) @@ -264,10 +265,6 @@ def __init__( # noqa: PLR0913 exists_function=self._exists, ) - @staticmethod - def _get_spark() -> SparkSession: - return SparkSession.builder.getOrCreate() - def _load(self) -> Union[DataFrame, pd.DataFrame]: """Loads the version of data in the format defined in the init (spark|pandas dataframe) @@ -283,7 +280,7 @@ def _load(self) -> 
Union[DataFrame, pd.DataFrame]: if self._version and self._version.load >= 0: try: data = ( - self._get_spark() + _get_spark() .read.format("delta") .option("versionAsOf", self._version.load) .table(self._table.full_table_location()) @@ -291,7 +288,7 @@ def _load(self) -> Union[DataFrame, pd.DataFrame]: except Exception as exc: raise VersionNotFoundError(self._version.load) from exc else: - data = self._get_spark().table(self._table.full_table_location()) + data = _get_spark().table(self._table.full_table_location()) if self._table.dataframe_type == "pandas": data = data.toPandas() return data @@ -329,7 +326,7 @@ def _save_upsert(self, update_data: DataFrame) -> None: update_data (DataFrame): the Spark dataframe to upsert """ if self._exists(): - base_data = self._get_spark().table(self._table.full_table_location()) + base_data = _get_spark().table(self._table.full_table_location()) base_columns = base_data.columns update_columns = update_data.columns @@ -352,13 +349,11 @@ def _save_upsert(self, update_data: DataFrame) -> None: ) update_data.createOrReplaceTempView("update") - self._get_spark().conf.set( - "fullTableAddress", self._table.full_table_location() - ) - self._get_spark().conf.set("whereExpr", where_expr) + _get_spark().conf.set("fullTableAddress", self._table.full_table_location()) + _get_spark().conf.set("whereExpr", where_expr) upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr} WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *""" - self._get_spark().sql(upsert_sql) + _get_spark().sql(upsert_sql) else: self._save_append(update_data) @@ -380,13 +375,13 @@ def _save(self, data: Union[DataFrame, pd.DataFrame]) -> None: if self._table.schema(): cols = self._table.schema().fieldNames() if self._table.dataframe_type == "pandas": - data = self._get_spark().createDataFrame( + data = _get_spark().createDataFrame( data.loc[:, cols], schema=self._table.schema() ) else: data = data.select(*cols) elif 
self._table.dataframe_type == "pandas": - data = self._get_spark().createDataFrame(data) + data = _get_spark().createDataFrame(data) if self._table.write_mode == "overwrite": self._save_overwrite(data) elif self._table.write_mode == "upsert": @@ -421,7 +416,7 @@ def _exists(self) -> bool: """ if self._table.catalog: try: - self._get_spark().sql(f"USE CATALOG `{self._table.catalog}`") + _get_spark().sql(f"USE CATALOG `{self._table.catalog}`") except (ParseException, AnalysisException) as exc: logger.warning( "catalog %s not found or unity not enabled. Error message: %s", @@ -430,7 +425,7 @@ def _exists(self) -> bool: ) try: return ( - self._get_spark() + _get_spark() .sql(f"SHOW TABLES IN `{self._table.database}`") .filter(f"tableName = '{self._table.table}'") .count() diff --git a/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py b/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py index 7a294554c..fb52e79ca 100644 --- a/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py @@ -6,12 +6,15 @@ from typing import Any, Dict, NoReturn from delta.tables import DeltaTable -from pyspark.sql import SparkSession from pyspark.sql.utils import AnalysisException from kedro_datasets import KedroDeprecationWarning from kedro_datasets._io import AbstractDataset, DatasetError -from kedro_datasets.spark.spark_dataset import _split_filepath, _strip_dbfs_prefix +from kedro_datasets.spark.spark_dataset import ( + _get_spark, + _split_filepath, + _strip_dbfs_prefix, +) class DeltaTableDataset(AbstractDataset[None, DeltaTable]): @@ -81,13 +84,9 @@ def __init__(self, filepath: str, metadata: Dict[str, Any] = None) -> None: self._filepath = PurePosixPath(filepath) self.metadata = metadata - @staticmethod - def _get_spark(): - return SparkSession.builder.getOrCreate() - def _load(self) -> DeltaTable: load_path = self._fs_prefix + str(self._filepath) - return DeltaTable.forPath(self._get_spark(), load_path) + 
return DeltaTable.forPath(_get_spark(), load_path) def _save(self, data: None) -> NoReturn: raise DatasetError(f"{self.__class__.__name__} is a read only dataset type") @@ -96,7 +95,7 @@ def _exists(self) -> bool: load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) try: - self._get_spark().read.load(path=load_path, format="delta") + _get_spark().read.load(path=load_path, format="delta") except AnalysisException as exception: # `AnalysisException.desc` is deprecated with pyspark >= 3.4 message = exception.desc if hasattr(exception, "desc") else str(exception) diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_dataset.py index 58df800c8..e43404ff9 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset.py @@ -31,6 +31,29 @@ logger = logging.getLogger(__name__) +def _get_spark() -> Any: + """ + Returns the SparkSession. If databricks-connect is installed, we use it for + extended configuration mechanisms and notebook compatibility; + otherwise we fall back to classic pyspark.
+ """ + try: + # When using databricks-connect >= 13.0.0 (a.k.a. databricks-connect-v2) + # the remote session is instantiated via the `databricks.connect` module. + # If the databricks-connect module is installed, we use a remote session. + from databricks.connect import DatabricksSession + + # We can't test this as there's no Databricks test env available + spark = DatabricksSession.builder.getOrCreate() # pragma: no cover + + except ImportError: + # databricks-connect is not installed (only ImportError is handled here), + # so we fall back to a regular pyspark session + spark = SparkSession.builder.getOrCreate() + + return spark + + def _parse_glob_pattern(pattern: str) -> str: special = ("*", "?", "[") clean = [] @@ -324,7 +347,7 @@ def __init__( # noqa: PLR0913 elif filepath.startswith("/dbfs/"): # dbfs add prefix to Spark path by default # See https://github.com/kedro-org/kedro-plugins/issues/117 - dbutils = _get_dbutils(self._get_spark()) + dbutils = _get_dbutils(_get_spark()) if dbutils: glob_function = partial(_dbfs_glob, dbutils=dbutils) exists_function = partial(_dbfs_exists, dbutils=dbutils) @@ -392,13 +415,9 @@ def _describe(self) -> Dict[str, Any]: "version": self._version, } - @staticmethod - def _get_spark(): - return SparkSession.builder.getOrCreate() - def _load(self) -> DataFrame: load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) - read_obj = self._get_spark().read + read_obj = _get_spark().read # Pass schema if defined if self._schema: @@ -414,7 +433,7 @@ def _exists(self) -> bool: load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) try: - self._get_spark().read.load(load_path, self._file_format) + _get_spark().read.load(load_path, self._file_format) except AnalysisException as exception: # `AnalysisException.desc` is deprecated with pyspark >= 3.4 message = exception.desc if hasattr(exception, "desc") else str(exception) diff --git a/kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py
b/kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py index b7bd3363c..fe9bcadf9 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py @@ -6,11 +6,12 @@ from copy import deepcopy from typing import Any, Dict, List -from pyspark.sql import DataFrame, SparkSession, Window +from pyspark.sql import DataFrame, Window from pyspark.sql.functions import col, lit, row_number from kedro_datasets import KedroDeprecationWarning from kedro_datasets._io import AbstractDataset, DatasetError +from kedro_datasets.spark.spark_dataset import _get_spark class SparkHiveDataset(AbstractDataset[DataFrame, DataFrame]): @@ -137,20 +138,6 @@ def _describe(self) -> Dict[str, Any]: "format": self._format, } - @staticmethod - def _get_spark() -> SparkSession: - """ - This method should only be used to get an existing SparkSession - with valid Hive configuration. - Configuration for Hive is read from hive-site.xml on the classpath. - It supports running both SQL and HiveQL commands. - Additionally, if users are leveraging the `upsert` functionality, - then a `checkpoint` directory must be set, e.g. 
using - `spark.sparkContext.setCheckpointDir("/path/to/dir")` - """ - _spark = SparkSession.builder.getOrCreate() - return _spark - def _create_hive_table(self, data: DataFrame, mode: str = None): _mode: str = mode or self._write_mode data.write.saveAsTable( @@ -161,7 +148,7 @@ def _create_hive_table(self, data: DataFrame, mode: str = None): ) def _load(self) -> DataFrame: - return self._get_spark().read.table(self._full_table_address) + return _get_spark().read.table(self._full_table_address) def _save(self, data: DataFrame) -> None: self._validate_save(data) @@ -213,7 +200,7 @@ def _validate_save(self, data: DataFrame): def _exists(self) -> bool: return ( - self._get_spark() + _get_spark() ._jsparkSession.catalog() .tableExists(self._database, self._table) ) diff --git a/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py index 029cf15b5..7d84c1f90 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py @@ -3,10 +3,11 @@ from copy import deepcopy from typing import Any, Dict -from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import DataFrame from kedro_datasets import KedroDeprecationWarning from kedro_datasets._io import AbstractDataset, DatasetError +from kedro_datasets.spark.spark_dataset import _get_spark class SparkJDBCDataset(AbstractDataset[DataFrame, DataFrame]): @@ -169,12 +170,8 @@ def _describe(self) -> Dict[str, Any]: "save_args": save_args, } - @staticmethod - def _get_spark(): # pragma: no cover - return SparkSession.builder.getOrCreate() - def _load(self) -> DataFrame: - return self._get_spark().read.jdbc(self._url, self._table, **self._load_args) + return _get_spark().read.jdbc(self._url, self._table, **self._load_args) def _save(self, data: DataFrame) -> None: return data.write.jdbc(self._url, self._table, **self._save_args) diff --git 
a/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py index 7ebe84ae4..cea59adc7 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py @@ -4,13 +4,14 @@ from pathlib import PurePosixPath from typing import Any, Dict -from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import DataFrame from pyspark.sql.utils import AnalysisException from kedro_datasets import KedroDeprecationWarning from kedro_datasets._io import AbstractDataset from kedro_datasets.spark.spark_dataset import ( SparkDataset, + _get_spark, _split_filepath, _strip_dbfs_prefix, ) @@ -104,10 +105,6 @@ def _describe(self) -> Dict[str, Any]: "save_args": self._save_args, } - @staticmethod - def _get_spark(): - return SparkSession.builder.getOrCreate() - def _load(self) -> DataFrame: """Loads data from filepath. If the connector type is kafka then no file_path is required, schema needs to be @@ -117,7 +114,7 @@ def _load(self) -> DataFrame: """ load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) data_stream_reader = ( - self._get_spark() + _get_spark() .readStream.schema(self._schema) .format(self._file_format) .options(**self._load_args) @@ -146,7 +143,7 @@ def _exists(self) -> bool: load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) try: - self._get_spark().readStream.schema(self._schema).load( + _get_spark().readStream.schema(self._schema).load( load_path, self._file_format ) except AnalysisException as exception: diff --git a/kedro-datasets/tests/spark/test_deltatable_dataset.py b/kedro-datasets/tests/spark/test_deltatable_dataset.py index 58940f5ce..525d2cdb0 100644 --- a/kedro-datasets/tests/spark/test_deltatable_dataset.py +++ b/kedro-datasets/tests/spark/test_deltatable_dataset.py @@ -86,13 +86,13 @@ def test_exists(self, tmp_path, sample_spark_df): def test_exists_raises_error(self, 
mocker): delta_ds = DeltaTableDataset(filepath="") if SPARK_VERSION >= Version("3.4.0"): - mocker.patch.object( - delta_ds, "_get_spark", side_effect=AnalysisException("Other Exception") + mocker.patch( + "kedro_datasets.spark.deltatable_dataset._get_spark", + side_effect=AnalysisException("Other Exception"), ) else: - mocker.patch.object( - delta_ds, - "_get_spark", + mocker.patch( + "kedro_datasets.spark.deltatable_dataset._get_spark", side_effect=AnalysisException("Other Exception", []), ) with pytest.raises(DatasetError, match="Other Exception"): diff --git a/kedro-datasets/tests/spark/test_spark_dataset.py b/kedro-datasets/tests/spark/test_spark_dataset.py index 032c2a0ee..44358adb2 100644 --- a/kedro-datasets/tests/spark/test_spark_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_dataset.py @@ -423,15 +423,13 @@ def test_exists_raises_error(self, mocker): # AnalysisExceptions clearly indicating a missing file spark_dataset = SparkDataset(filepath="") if SPARK_VERSION >= PackagingVersion("3.4.0"): - mocker.patch.object( - spark_dataset, - "_get_spark", + mocker.patch( + "kedro_datasets.spark.spark_dataset._get_spark", side_effect=AnalysisException("Other Exception"), ) else: - mocker.patch.object( - spark_dataset, - "_get_spark", + mocker.patch( + "kedro_datasets.spark.spark_dataset._get_spark", side_effect=AnalysisException("Other Exception", []), ) with pytest.raises(DatasetError, match="Other Exception"): @@ -748,7 +746,9 @@ def test_no_version(self, versioned_dataset_s3): versioned_dataset_s3.load() def test_load_latest(self, mocker, versioned_dataset_s3): - get_spark = mocker.patch.object(versioned_dataset_s3, "_get_spark") + get_spark = mocker.patch( + "kedro_datasets.spark.spark_dataset._get_spark", + ) mocked_glob = mocker.patch.object(versioned_dataset_s3, "_glob_function") mocked_glob.return_value = [ "{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v="mocked_version") @@ -771,8 +771,9 @@ def test_load_exact(self, mocker): 
filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", version=Version(ts, None), ) - get_spark = mocker.patch.object(ds_s3, "_get_spark") - + get_spark = mocker.patch( + "kedro_datasets.spark.spark_dataset._get_spark", + ) ds_s3.load() get_spark.return_value.read.load.assert_called_once_with( @@ -864,7 +865,10 @@ def test_load_latest(self, mocker, version): hdfs_walk.return_value = HDFS_FOLDER_STRUCTURE versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) - get_spark = mocker.patch.object(versioned_hdfs, "_get_spark") + + get_spark = mocker.patch( + "kedro_datasets.spark.spark_dataset._get_spark", + ) versioned_hdfs.load() @@ -881,7 +885,9 @@ def test_load_exact(self, mocker): versioned_hdfs = SparkDataset( filepath=f"hdfs://{HDFS_PREFIX}", version=Version(ts, None) ) - get_spark = mocker.patch.object(versioned_hdfs, "_get_spark") + get_spark = mocker.patch( + "kedro_datasets.spark.spark_dataset._get_spark", + ) versioned_hdfs.load() diff --git a/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py b/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py index e9bb33ddb..3e11a0877 100644 --- a/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py @@ -102,14 +102,18 @@ def test_except_bad_credentials(mocker, spark_jdbc_args_credentials_with_none_pa def test_load(mocker, spark_jdbc_args): - spark = mocker.patch.object(SparkJDBCDataset, "_get_spark").return_value + spark = mocker.patch( + "kedro_datasets.spark.spark_jdbc_dataset._get_spark" + ).return_value dataset = SparkJDBCDataset(**spark_jdbc_args) dataset.load() spark.read.jdbc.assert_called_with("dummy_url", "dummy_table") def test_load_credentials(mocker, spark_jdbc_args_credentials): - spark = mocker.patch.object(SparkJDBCDataset, "_get_spark").return_value + spark = mocker.patch( + "kedro_datasets.spark.spark_jdbc_dataset._get_spark" + ).return_value dataset = SparkJDBCDataset(**spark_jdbc_args_credentials) dataset.load() 
spark.read.jdbc.assert_called_with( @@ -120,7 +124,9 @@ def test_load_credentials(mocker, spark_jdbc_args_credentials): def test_load_args(mocker, spark_jdbc_args_save_load): - spark = mocker.patch.object(SparkJDBCDataset, "_get_spark").return_value + spark = mocker.patch( + "kedro_datasets.spark.spark_jdbc_dataset._get_spark" + ).return_value dataset = SparkJDBCDataset(**spark_jdbc_args_save_load) dataset.load() spark.read.jdbc.assert_called_with( diff --git a/kedro-datasets/tests/spark/test_spark_streaming_dataset.py b/kedro-datasets/tests/spark/test_spark_streaming_dataset.py index d199df812..236d076a4 100644 --- a/kedro-datasets/tests/spark/test_spark_streaming_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_streaming_dataset.py @@ -188,15 +188,13 @@ def test_exists_raises_error(self, mocker): spark_dataset = SparkStreamingDataset(filepath="") if SPARK_VERSION >= Version("3.4.0"): - mocker.patch.object( - spark_dataset, - "_get_spark", + mocker.patch( + "kedro_datasets.spark.spark_streaming_dataset._get_spark", side_effect=AnalysisException("Other Exception"), ) else: - mocker.patch.object( - spark_dataset, - "_get_spark", + mocker.patch( + "kedro_datasets.spark.spark_streaming_dataset._get_spark", side_effect=AnalysisException("Other Exception", []), ) with pytest.raises(DatasetError, match="Other Exception"):