
Commit

feat(datasets): Add support for databricks-connect>=13.0 (#352)
Signed-off-by: Miguel Rodriguez Gutierrez <[email protected]>
MigQ2 authored Nov 1, 2023
1 parent 08214ad commit 16c216c
Showing 11 changed files with 96 additions and 88 deletions.
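
For context on this change: Spark- and Databricks-based datasets now resolve their SparkSession through a single module-level helper, `_get_spark`, which prefers databricks-connect when it is installed and falls back to classic PySpark otherwise. A minimal sketch of how one might check which backend is picked up (assuming a kedro-datasets install that includes this change; `_get_spark` is a private helper, so this is for illustration only):

```python
# Minimal check of which Spark backend kedro-datasets resolves:
# a DatabricksSession when databricks-connect>=13.0 is installed and
# configured, otherwise a classic local SparkSession.
from kedro_datasets.spark.spark_dataset import _get_spark

spark = _get_spark()
print(type(spark).__module__, type(spark).__name__)
```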
6 changes: 5 additions & 1 deletion kedro-datasets/RELEASE.md
@@ -1,12 +1,16 @@
# Upcoming Release
* Removed support for Python 3.7

## Major features and improvements
* Removed support for Python 3.7
* Spark and Databricks based datasets now support [databricks-connect>=13.0](https://docs.databricks.com/en/dev-tools/databricks-connect-ref.html)

## Bug fixes and other changes
* Fixed bug with loading models saved with `TensorFlowModelDataset`.

## Community contributions
Many thanks to the following Kedroids for contributing PRs to this release:
* [Edouard59](https://github.com/Edouard59)
* [Miguel Rodriguez Gutierrez](https://github.com/MigQ2)

# Release 1.8.0
## Major features and improvements
29 changes: 12 additions & 17 deletions kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py
@@ -9,12 +9,13 @@

import pandas as pd
from kedro.io.core import Version, VersionNotFoundError
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType
from pyspark.sql.utils import AnalysisException, ParseException

from kedro_datasets import KedroDeprecationWarning
from kedro_datasets._io import AbstractVersionedDataset, DatasetError
from kedro_datasets.spark.spark_dataset import _get_spark

logger = logging.getLogger(__name__)

@@ -264,10 +265,6 @@ def __init__( # noqa: PLR0913
exists_function=self._exists,
)

@staticmethod
def _get_spark() -> SparkSession:
return SparkSession.builder.getOrCreate()

def _load(self) -> Union[DataFrame, pd.DataFrame]:
"""Loads the version of data in the format defined in the init
(spark|pandas dataframe)
@@ -283,15 +280,15 @@ def _load(self) -> Union[DataFrame, pd.DataFrame]:
if self._version and self._version.load >= 0:
try:
data = (
self._get_spark()
_get_spark()
.read.format("delta")
.option("versionAsOf", self._version.load)
.table(self._table.full_table_location())
)
except Exception as exc:
raise VersionNotFoundError(self._version.load) from exc
else:
data = self._get_spark().table(self._table.full_table_location())
data = _get_spark().table(self._table.full_table_location())
if self._table.dataframe_type == "pandas":
data = data.toPandas()
return data
@@ -329,7 +326,7 @@ def _save_upsert(self, update_data: DataFrame) -> None:
update_data (DataFrame): the Spark dataframe to upsert
"""
if self._exists():
base_data = self._get_spark().table(self._table.full_table_location())
base_data = _get_spark().table(self._table.full_table_location())
base_columns = base_data.columns
update_columns = update_data.columns

@@ -352,13 +349,11 @@ def _save_upsert(self, update_data: DataFrame) -> None:
)

update_data.createOrReplaceTempView("update")
self._get_spark().conf.set(
"fullTableAddress", self._table.full_table_location()
)
self._get_spark().conf.set("whereExpr", where_expr)
_get_spark().conf.set("fullTableAddress", self._table.full_table_location())
_get_spark().conf.set("whereExpr", where_expr)
upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr}
WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *"""
self._get_spark().sql(upsert_sql)
_get_spark().sql(upsert_sql)
else:
self._save_append(update_data)

@@ -380,13 +375,13 @@ def _save(self, data: Union[DataFrame, pd.DataFrame]) -> None:
if self._table.schema():
cols = self._table.schema().fieldNames()
if self._table.dataframe_type == "pandas":
data = self._get_spark().createDataFrame(
data = _get_spark().createDataFrame(
data.loc[:, cols], schema=self._table.schema()
)
else:
data = data.select(*cols)
elif self._table.dataframe_type == "pandas":
data = self._get_spark().createDataFrame(data)
data = _get_spark().createDataFrame(data)
if self._table.write_mode == "overwrite":
self._save_overwrite(data)
elif self._table.write_mode == "upsert":
@@ -421,7 +416,7 @@ def _exists(self) -> bool:
"""
if self._table.catalog:
try:
self._get_spark().sql(f"USE CATALOG `{self._table.catalog}`")
_get_spark().sql(f"USE CATALOG `{self._table.catalog}`")
except (ParseException, AnalysisException) as exc:
logger.warning(
"catalog %s not found or unity not enabled. Error message: %s",
@@ -430,7 +425,7 @@
)
try:
return (
self._get_spark()
_get_spark()
.sql(f"SHOW TABLES IN `{self._table.database}`")
.filter(f"tableName = '{self._table.table}'")
.count()
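
A note on the `_save_upsert` hunk above (behaviour unchanged by this commit): the `${fullTableAddress}` and `${whereExpr}` placeholders in the MERGE statement are filled in by Spark SQL variable substitution from values set on the session conf. A rough standalone sketch of the same mechanism, with illustrative table and column names and assuming an `update` temporary view has already been registered:

```python
# Sketch of the conf-driven substitution used by _save_upsert: values set
# on spark.conf are referenced as ${name} inside the SQL text.
# Table and column names are illustrative; an "update" temp view must exist.
from kedro_datasets.spark.spark_dataset import _get_spark

spark = _get_spark()
spark.conf.set("fullTableAddress", "`default`.`sales`")
spark.conf.set("whereExpr", "base.id = update.id")
spark.sql(
    """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr}
    WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *"""
)
```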
15 changes: 7 additions & 8 deletions kedro-datasets/kedro_datasets/spark/deltatable_dataset.py
@@ -6,12 +6,15 @@
from typing import Any, Dict, NoReturn

from delta.tables import DeltaTable
from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException

from kedro_datasets import KedroDeprecationWarning
from kedro_datasets._io import AbstractDataset, DatasetError
from kedro_datasets.spark.spark_dataset import _split_filepath, _strip_dbfs_prefix
from kedro_datasets.spark.spark_dataset import (
_get_spark,
_split_filepath,
_strip_dbfs_prefix,
)


class DeltaTableDataset(AbstractDataset[None, DeltaTable]):
@@ -81,13 +84,9 @@ def __init__(self, filepath: str, metadata: Dict[str, Any] = None) -> None:
self._filepath = PurePosixPath(filepath)
self.metadata = metadata

@staticmethod
def _get_spark():
return SparkSession.builder.getOrCreate()

def _load(self) -> DeltaTable:
load_path = self._fs_prefix + str(self._filepath)
return DeltaTable.forPath(self._get_spark(), load_path)
return DeltaTable.forPath(_get_spark(), load_path)

def _save(self, data: None) -> NoReturn:
raise DatasetError(f"{self.__class__.__name__} is a read only dataset type")
@@ -96,7 +95,7 @@ def _exists(self) -> bool:
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath))

try:
self._get_spark().read.load(path=load_path, format="delta")
_get_spark().read.load(path=load_path, format="delta")
except AnalysisException as exception:
# `AnalysisException.desc` is deprecated with pyspark >= 3.4
message = exception.desc if hasattr(exception, "desc") else str(exception)
33 changes: 26 additions & 7 deletions kedro-datasets/kedro_datasets/spark/spark_dataset.py
@@ -31,6 +31,29 @@
logger = logging.getLogger(__name__)


def _get_spark() -> Any:
"""
Returns the SparkSession. If databricks-connect is available we use it for
extended configuration mechanisms and notebook compatibility;
otherwise we use classic pyspark.
"""
try:
# When using databricks-connect >= 13.0.0 (a.k.a. databricks-connect-v2),
# the remote session is instantiated using the databricks module.
# If the databricks-connect module is installed, we use a remote session.
from databricks.connect import DatabricksSession

# We can't test this as there's no Databricks test env available
spark = DatabricksSession.builder.getOrCreate() # pragma: no cover

except ImportError:
# For "normal" spark sessions that don't use databricks-connect
# we get spark normally
spark = SparkSession.builder.getOrCreate()

return spark


def _parse_glob_pattern(pattern: str) -> str:
special = ("*", "?", "[")
clean = []
@@ -324,7 +347,7 @@ def __init__( # noqa: PLR0913
elif filepath.startswith("/dbfs/"):
# dbfs add prefix to Spark path by default
# See https://github.com/kedro-org/kedro-plugins/issues/117
dbutils = _get_dbutils(self._get_spark())
dbutils = _get_dbutils(_get_spark())
if dbutils:
glob_function = partial(_dbfs_glob, dbutils=dbutils)
exists_function = partial(_dbfs_exists, dbutils=dbutils)
@@ -392,13 +415,9 @@ def _describe(self) -> Dict[str, Any]:
"version": self._version,
}

@staticmethod
def _get_spark():
return SparkSession.builder.getOrCreate()

def _load(self) -> DataFrame:
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path()))
read_obj = self._get_spark().read
read_obj = _get_spark().read

# Pass schema if defined
if self._schema:
@@ -414,7 +433,7 @@ def _exists(self) -> bool:
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path()))

try:
self._get_spark().read.load(load_path, self._file_format)
_get_spark().read.load(load_path, self._file_format)
except AnalysisException as exception:
# `AnalysisException.desc` is deprecated with pyspark >= 3.4
message = exception.desc if hasattr(exception, "desc") else str(exception)
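
Because there is no Databricks environment available in CI (hence the `# pragma: no cover` above), only the ImportError fallback of `_get_spark` is exercised by tests. A hedged sketch of how that branch could be forced locally, assuming pytest with pytest-mock (hypothetical test, not part of this commit):

```python
import builtins

from pyspark.sql import SparkSession

from kedro_datasets.spark.spark_dataset import _get_spark


def test_get_spark_falls_back_to_pyspark(mocker):
    """Hide databricks.connect so _get_spark takes the plain pyspark path."""
    real_import = builtins.__import__

    def _no_databricks(name, *args, **kwargs):
        if name.startswith("databricks"):
            raise ImportError(name)
        return real_import(name, *args, **kwargs)

    mocker.patch("builtins.__import__", side_effect=_no_databricks)
    assert isinstance(_get_spark(), SparkSession)
```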
21 changes: 4 additions & 17 deletions kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py
@@ -6,11 +6,12 @@
from copy import deepcopy
from typing import Any, Dict, List

from pyspark.sql import DataFrame, SparkSession, Window
from pyspark.sql import DataFrame, Window
from pyspark.sql.functions import col, lit, row_number

from kedro_datasets import KedroDeprecationWarning
from kedro_datasets._io import AbstractDataset, DatasetError
from kedro_datasets.spark.spark_dataset import _get_spark


class SparkHiveDataset(AbstractDataset[DataFrame, DataFrame]):
@@ -137,20 +138,6 @@ def _describe(self) -> Dict[str, Any]:
"format": self._format,
}

@staticmethod
def _get_spark() -> SparkSession:
"""
This method should only be used to get an existing SparkSession
with valid Hive configuration.
Configuration for Hive is read from hive-site.xml on the classpath.
It supports running both SQL and HiveQL commands.
Additionally, if users are leveraging the `upsert` functionality,
then a `checkpoint` directory must be set, e.g. using
`spark.sparkContext.setCheckpointDir("/path/to/dir")`
"""
_spark = SparkSession.builder.getOrCreate()
return _spark

def _create_hive_table(self, data: DataFrame, mode: str = None):
_mode: str = mode or self._write_mode
data.write.saveAsTable(
@@ -161,7 +148,7 @@ def _create_hive_table(self, data: DataFrame, mode: str = None):
)

def _load(self) -> DataFrame:
return self._get_spark().read.table(self._full_table_address)
return _get_spark().read.table(self._full_table_address)

def _save(self, data: DataFrame) -> None:
self._validate_save(data)
@@ -213,7 +200,7 @@ def _validate_save(self, data: DataFrame):

def _exists(self) -> bool:
return (
self._get_spark()
_get_spark()
._jsparkSession.catalog()
.tableExists(self._database, self._table)
)
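
The static-method docstring removed above carried a usage note worth keeping in mind: the `upsert` write mode of `SparkHiveDataset` requires a checkpoint directory on the active session. With the shared helper, that setup would look roughly like the sketch below (the path is illustrative; note that remote databricks-connect sessions may not expose `sparkContext`):

```python
# Sketch: set the checkpoint directory mentioned in the removed docstring
# on whichever session _get_spark() resolves. The path is illustrative.
from kedro_datasets.spark.spark_dataset import _get_spark

spark = _get_spark()
spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")
```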
9 changes: 3 additions & 6 deletions kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py
@@ -3,10 +3,11 @@
from copy import deepcopy
from typing import Any, Dict

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import DataFrame

from kedro_datasets import KedroDeprecationWarning
from kedro_datasets._io import AbstractDataset, DatasetError
from kedro_datasets.spark.spark_dataset import _get_spark


class SparkJDBCDataset(AbstractDataset[DataFrame, DataFrame]):
@@ -169,12 +170,8 @@ def _describe(self) -> Dict[str, Any]:
"save_args": save_args,
}

@staticmethod
def _get_spark(): # pragma: no cover
return SparkSession.builder.getOrCreate()

def _load(self) -> DataFrame:
return self._get_spark().read.jdbc(self._url, self._table, **self._load_args)
return _get_spark().read.jdbc(self._url, self._table, **self._load_args)

def _save(self, data: DataFrame) -> None:
return data.write.jdbc(self._url, self._table, **self._save_args)
11 changes: 4 additions & 7 deletions kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py
@@ -4,13 +4,14 @@
from pathlib import PurePosixPath
from typing import Any, Dict

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import DataFrame
from pyspark.sql.utils import AnalysisException

from kedro_datasets import KedroDeprecationWarning
from kedro_datasets._io import AbstractDataset
from kedro_datasets.spark.spark_dataset import (
SparkDataset,
_get_spark,
_split_filepath,
_strip_dbfs_prefix,
)
@@ -104,10 +105,6 @@ def _describe(self) -> Dict[str, Any]:
"save_args": self._save_args,
}

@staticmethod
def _get_spark():
return SparkSession.builder.getOrCreate()

def _load(self) -> DataFrame:
"""Loads data from filepath.
If the connector type is kafka then no file_path is required, schema needs to be
@@ -117,7 +114,7 @@
"""
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath))
data_stream_reader = (
self._get_spark()
_get_spark()
.readStream.schema(self._schema)
.format(self._file_format)
.options(**self._load_args)
@@ -146,7 +143,7 @@ def _exists(self) -> bool:
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath))

try:
self._get_spark().readStream.schema(self._schema).load(
_get_spark().readStream.schema(self._schema).load(
load_path, self._file_format
)
except AnalysisException as exception:
10 changes: 5 additions & 5 deletions kedro-datasets/tests/spark/test_deltatable_dataset.py
@@ -86,13 +86,13 @@ def test_exists(self, tmp_path, sample_spark_df):
def test_exists_raises_error(self, mocker):
delta_ds = DeltaTableDataset(filepath="")
if SPARK_VERSION >= Version("3.4.0"):
mocker.patch.object(
delta_ds, "_get_spark", side_effect=AnalysisException("Other Exception")
mocker.patch(
"kedro_datasets.spark.deltatable_dataset._get_spark",
side_effect=AnalysisException("Other Exception"),
)
else:
mocker.patch.object(
delta_ds,
"_get_spark",
mocker.patch(
"kedro_datasets.spark.deltatable_dataset._get_spark",
side_effect=AnalysisException("Other Exception", []),
)
with pytest.raises(DatasetError, match="Other Exception"):
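
With `_get_spark` now a module-level function rather than a method, tests patch it in the namespace of the module that imports it, as the updated `DeltaTableDataset` test above does. A hypothetical sketch of the same pattern for the managed-table dataset (not part of this diff; class name, constructor arguments, and pytest-mock usage are assumptions):

```python
from kedro_datasets.databricks import ManagedTableDataset


def test_load_uses_patched_session(mocker):
    """Patch _get_spark where managed_table_dataset looks it up."""
    mocked_spark = mocker.MagicMock()
    mocker.patch(
        "kedro_datasets.databricks.managed_table_dataset._get_spark",
        return_value=mocked_spark,
    )
    ds = ManagedTableDataset(table="sales", database="default")
    ds.load()
    mocked_spark.table.assert_called_once()
```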