feat(datasets): Add support for databricks-connect>=13.0 #352

Merged: 20 commits, Nov 1, 2023
6 changes: 5 additions & 1 deletion kedro-datasets/RELEASE.md
@@ -1,12 +1,16 @@
# Upcoming Release
* Removed support for Python 3.7

## Major features and improvements
* Removed support for Python 3.7
* Spark and Databricks based datasets now support [databricks-connect>=13.0](https://docs.databricks.com/en/dev-tools/databricks-connect-ref.html)

## Bug fixes and other changes
* Fixed bug with loading models saved with `TensorFlowModelDataset`.

## Community contributions
Many thanks to the following Kedroids for contributing PRs to this release:
* [Edouard59](https://github.com/Edouard59)
* [Miguel Rodriguez Gutierrez](https://github.com/MigQ2)

# Release 1.8.0
## Major features and improvements
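To make the release note above concrete, here is a minimal, hedged sketch of what the change means for users (it is not part of the PR): the Spark and Databricks datasets now obtain their session from a shared `_get_spark()` helper, which prefers a databricks-connect session when that package is installed. The credential setup mentioned in the comments (config profile, `DATABRICKS_*` environment variables) is an assumption about a typical environment.

```python
# A minimal sketch, not part of the PR: what the new behaviour looks like from a
# user's environment. Assumes databricks-connect>=13.0 is installed and that
# Databricks credentials (host, token, cluster) are already configured, e.g. via
# a ~/.databrickscfg profile or DATABRICKS_* environment variables.
from kedro_datasets.spark.spark_dataset import _get_spark

spark = _get_spark()
# With databricks-connect installed this is a remote Spark Connect session
# created via DatabricksSession; without it, a plain local SparkSession.
print(type(spark).__module__, type(spark).__name__)
print(spark.range(5).count())  # executes on the configured cluster when remote
```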
29 changes: 12 additions & 17 deletions kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py
@@ -9,12 +9,13 @@

import pandas as pd
from kedro.io.core import Version, VersionNotFoundError
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType
from pyspark.sql.utils import AnalysisException, ParseException

from kedro_datasets import KedroDeprecationWarning
from kedro_datasets._io import AbstractVersionedDataset, DatasetError
from kedro_datasets.spark.spark_dataset import _get_spark

logger = logging.getLogger(__name__)

@@ -264,10 +265,6 @@ def __init__( # noqa: PLR0913
exists_function=self._exists,
)

@staticmethod
def _get_spark() -> SparkSession:
return SparkSession.builder.getOrCreate()

def _load(self) -> Union[DataFrame, pd.DataFrame]:
"""Loads the version of data in the format defined in the init
(spark|pandas dataframe)
@@ -283,15 +280,15 @@ def _load(self) -> Union[DataFrame, pd.DataFrame]:
if self._version and self._version.load >= 0:
try:
data = (
self._get_spark()
_get_spark()
.read.format("delta")
.option("versionAsOf", self._version.load)
.table(self._table.full_table_location())
)
except Exception as exc:
raise VersionNotFoundError(self._version.load) from exc
else:
data = self._get_spark().table(self._table.full_table_location())
data = _get_spark().table(self._table.full_table_location())
if self._table.dataframe_type == "pandas":
data = data.toPandas()
return data
@@ -329,7 +326,7 @@ def _save_upsert(self, update_data: DataFrame) -> None:
update_data (DataFrame): the Spark dataframe to upsert
"""
if self._exists():
base_data = self._get_spark().table(self._table.full_table_location())
base_data = _get_spark().table(self._table.full_table_location())
base_columns = base_data.columns
update_columns = update_data.columns

@@ -352,13 +349,11 @@ def _save_upsert(self, update_data: DataFrame) -> None:
)

update_data.createOrReplaceTempView("update")
self._get_spark().conf.set(
"fullTableAddress", self._table.full_table_location()
)
self._get_spark().conf.set("whereExpr", where_expr)
_get_spark().conf.set("fullTableAddress", self._table.full_table_location())
_get_spark().conf.set("whereExpr", where_expr)
upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr}
WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *"""
self._get_spark().sql(upsert_sql)
_get_spark().sql(upsert_sql)
else:
self._save_append(update_data)
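
As an aside on the upsert branch above: the `MERGE INTO` template is resolved through Spark SQL variable substitution, where `${...}` placeholders are looked up in the session conf (enabled by default via `spark.sql.variable.substitute`). A standalone sketch of just that mechanism, with illustrative names and a plain `SELECT` instead of a Delta `MERGE`:

```python
# Standalone illustration (hypothetical names) of the ${...} conf substitution
# used by _save_upsert above. A SELECT is used here because MERGE INTO requires
# a Delta table, which this minimal local session does not configure.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.range(5).createOrReplaceTempView("update")

# Values set on the session conf are substituted into ${...} placeholders in SQL.
spark.conf.set("whereExpr", "id >= 3")
spark.sql("SELECT * FROM update WHERE ${whereExpr}").show()
```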

@@ -380,13 +375,13 @@ def _save(self, data: Union[DataFrame, pd.DataFrame]) -> None:
if self._table.schema():
cols = self._table.schema().fieldNames()
if self._table.dataframe_type == "pandas":
data = self._get_spark().createDataFrame(
data = _get_spark().createDataFrame(
data.loc[:, cols], schema=self._table.schema()
)
else:
data = data.select(*cols)
elif self._table.dataframe_type == "pandas":
data = self._get_spark().createDataFrame(data)
data = _get_spark().createDataFrame(data)
if self._table.write_mode == "overwrite":
self._save_overwrite(data)
elif self._table.write_mode == "upsert":
@@ -421,7 +416,7 @@ def _exists(self) -> bool:
"""
if self._table.catalog:
try:
self._get_spark().sql(f"USE CATALOG `{self._table.catalog}`")
_get_spark().sql(f"USE CATALOG `{self._table.catalog}`")
except (ParseException, AnalysisException) as exc:
logger.warning(
"catalog %s not found or unity not enabled. Error message: %s",
@@ -430,7 +425,7 @@
)
try:
return (
self._get_spark()
_get_spark()
.sql(f"SHOW TABLES IN `{self._table.database}`")
.filter(f"tableName = '{self._table.table}'")
.count()
15 changes: 7 additions & 8 deletions kedro-datasets/kedro_datasets/spark/deltatable_dataset.py
@@ -6,12 +6,15 @@
from typing import Any, Dict, NoReturn

from delta.tables import DeltaTable
from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException

from kedro_datasets import KedroDeprecationWarning
from kedro_datasets._io import AbstractDataset, DatasetError
from kedro_datasets.spark.spark_dataset import _split_filepath, _strip_dbfs_prefix
from kedro_datasets.spark.spark_dataset import (
_get_spark,
_split_filepath,
_strip_dbfs_prefix,
)


class DeltaTableDataset(AbstractDataset[None, DeltaTable]):
@@ -81,13 +84,9 @@ def __init__(self, filepath: str, metadata: Dict[str, Any] = None) -> None:
self._filepath = PurePosixPath(filepath)
self.metadata = metadata

@staticmethod
def _get_spark():
return SparkSession.builder.getOrCreate()

def _load(self) -> DeltaTable:
load_path = self._fs_prefix + str(self._filepath)
return DeltaTable.forPath(self._get_spark(), load_path)
return DeltaTable.forPath(_get_spark(), load_path)

def _save(self, data: None) -> NoReturn:
raise DatasetError(f"{self.__class__.__name__} is a read only dataset type")
@@ -96,7 +95,7 @@ def _exists(self) -> bool:
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath))

try:
self._get_spark().read.load(path=load_path, format="delta")
_get_spark().read.load(path=load_path, format="delta")
except AnalysisException as exception:
# `AnalysisException.desc` is deprecated with pyspark >= 3.4
message = exception.desc if hasattr(exception, "desc") else str(exception)
33 changes: 26 additions & 7 deletions kedro-datasets/kedro_datasets/spark/spark_dataset.py
@@ -31,6 +31,29 @@
logger = logging.getLogger(__name__)


def _get_spark() -> Any:
"""
Returns the SparkSession. In case databricks-connect is available we use it for
extended configuration mechanisms and notebook compatibility,
otherwise we use classic pyspark.
"""
try:
# When using databricks-connect >= 13.0.0 (a.k.a databricks-connect-v2)
# the remote session is instantiated using the databricks module
# If the databricks-connect module is installed, we use a remote session
from databricks.connect import DatabricksSession

# We can't test this as there's no Databricks test env available
spark = DatabricksSession.builder.getOrCreate() # pragma: no cover

except ImportError:
# For "normal" spark sessions that don't use databricks-connect
# we get spark normally
spark = SparkSession.builder.getOrCreate()

[Inline review thread on the line above]

cdkrot:
For even further improved compatibility, add

    from pyspark.sql.connect.session import SparkSession as RemoteSparkSession

and replace this line with RemoteSparkSession.builder.getOrCreate().

Contributor Author:
Thanks a lot for looking into this @cdkrot, super appreciated.

Do you mean another try/except? So the fallback waterfall would be:

  1. DatabricksSession
  2. RemoteSparkSession
  3. SparkSession

I guess adding RemoteSparkSession would help add support for non-Databricks Spark Connect sessions?

cdkrot (Oct 4, 2023):
Sorry, my bad, the logic in the code is good as is 👍. sql.SparkSession will figure out whether it needs to create a sql.SparkSession or a connect.SparkSession.

Possibly it's best to rewrite the comments though. Essentially the only reason to fall back from dbconnect session creation to pyspark's is when the former is not installed, since dbconnect switches to pyspark automatically based on the environment.

[End of review thread]

return spark


def _parse_glob_pattern(pattern: str) -> str:
special = ("*", "?", "[")
clean = []
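
For the record, a minimal sketch of the three-step fallback the author floats in the review thread above (DatabricksSession, then a generic Spark Connect session, then classic pyspark). As the reviewer concludes, the merged two-step version is sufficient because `pyspark.sql.SparkSession` already decides between a classic and a Spark Connect session, so this variant is illustrative only and is not what the PR implements.

```python
from typing import Any


def _get_spark_with_remote_fallback() -> Any:  # hypothetical, not merged
    """Illustrates the DatabricksSession -> RemoteSparkSession -> SparkSession
    waterfall discussed in the review thread; the merged code skips step 2."""
    try:
        # databricks-connect >= 13.0 (a.k.a. databricks-connect-v2)
        from databricks.connect import DatabricksSession

        return DatabricksSession.builder.getOrCreate()
    except ImportError:
        pass
    try:
        # Generic (non-Databricks) Spark Connect session, per the reviewer's note
        from pyspark.sql.connect.session import SparkSession as RemoteSparkSession

        return RemoteSparkSession.builder.getOrCreate()
    except ImportError:
        # Classic pyspark session
        from pyspark.sql import SparkSession

        return SparkSession.builder.getOrCreate()
```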
@@ -324,7 +347,7 @@ def __init__( # noqa: PLR0913
elif filepath.startswith("/dbfs/"):
# dbfs add prefix to Spark path by default
# See https://github.com/kedro-org/kedro-plugins/issues/117
dbutils = _get_dbutils(self._get_spark())
dbutils = _get_dbutils(_get_spark())
if dbutils:
glob_function = partial(_dbfs_glob, dbutils=dbutils)
exists_function = partial(_dbfs_exists, dbutils=dbutils)
@@ -392,13 +415,9 @@ def _describe(self) -> Dict[str, Any]:
"version": self._version,
}

@staticmethod
def _get_spark():
return SparkSession.builder.getOrCreate()

def _load(self) -> DataFrame:
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path()))
read_obj = self._get_spark().read
read_obj = _get_spark().read

# Pass schema if defined
if self._schema:
@@ -414,7 +433,7 @@ def _exists(self) -> bool:
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path()))

try:
self._get_spark().read.load(load_path, self._file_format)
_get_spark().read.load(load_path, self._file_format)
except AnalysisException as exception:
# `AnalysisException.desc` is deprecated with pyspark >= 3.4
message = exception.desc if hasattr(exception, "desc") else str(exception)
21 changes: 4 additions & 17 deletions kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py
@@ -6,11 +6,12 @@
from copy import deepcopy
from typing import Any, Dict, List

from pyspark.sql import DataFrame, SparkSession, Window
from pyspark.sql import DataFrame, Window
from pyspark.sql.functions import col, lit, row_number

from kedro_datasets import KedroDeprecationWarning
from kedro_datasets._io import AbstractDataset, DatasetError
from kedro_datasets.spark.spark_dataset import _get_spark


class SparkHiveDataset(AbstractDataset[DataFrame, DataFrame]):
@@ -137,20 +138,6 @@ def _describe(self) -> Dict[str, Any]:
"format": self._format,
}

@staticmethod
def _get_spark() -> SparkSession:
"""
This method should only be used to get an existing SparkSession
with valid Hive configuration.
Configuration for Hive is read from hive-site.xml on the classpath.
It supports running both SQL and HiveQL commands.
Additionally, if users are leveraging the `upsert` functionality,
then a `checkpoint` directory must be set, e.g. using
`spark.sparkContext.setCheckpointDir("/path/to/dir")`
"""
_spark = SparkSession.builder.getOrCreate()
return _spark

def _create_hive_table(self, data: DataFrame, mode: str = None):
_mode: str = mode or self._write_mode
data.write.saveAsTable(
@@ -161,7 +148,7 @@ def _create_hive_table(self, data: DataFrame, mode: str = None):
)

def _load(self) -> DataFrame:
return self._get_spark().read.table(self._full_table_address)
return _get_spark().read.table(self._full_table_address)

def _save(self, data: DataFrame) -> None:
self._validate_save(data)
@@ -213,7 +200,7 @@ def _validate_save(self, data: DataFrame):

def _exists(self) -> bool:
return (
self._get_spark()
_get_spark()
._jsparkSession.catalog()
.tableExists(self._database, self._table)
)
9 changes: 3 additions & 6 deletions kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py
@@ -3,10 +3,11 @@
from copy import deepcopy
from typing import Any, Dict

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import DataFrame

from kedro_datasets import KedroDeprecationWarning
from kedro_datasets._io import AbstractDataset, DatasetError
from kedro_datasets.spark.spark_dataset import _get_spark


class SparkJDBCDataset(AbstractDataset[DataFrame, DataFrame]):
@@ -169,12 +170,8 @@ def _describe(self) -> Dict[str, Any]:
"save_args": save_args,
}

@staticmethod
def _get_spark(): # pragma: no cover
return SparkSession.builder.getOrCreate()

def _load(self) -> DataFrame:
return self._get_spark().read.jdbc(self._url, self._table, **self._load_args)
return _get_spark().read.jdbc(self._url, self._table, **self._load_args)

def _save(self, data: DataFrame) -> None:
return data.write.jdbc(self._url, self._table, **self._save_args)
11 changes: 4 additions & 7 deletions kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py
@@ -4,13 +4,14 @@
from pathlib import PurePosixPath
from typing import Any, Dict

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import DataFrame
from pyspark.sql.utils import AnalysisException

from kedro_datasets import KedroDeprecationWarning
from kedro_datasets._io import AbstractDataset
from kedro_datasets.spark.spark_dataset import (
SparkDataset,
_get_spark,
_split_filepath,
_strip_dbfs_prefix,
)
@@ -104,10 +105,6 @@ def _describe(self) -> Dict[str, Any]:
"save_args": self._save_args,
}

@staticmethod
def _get_spark():
return SparkSession.builder.getOrCreate()

def _load(self) -> DataFrame:
"""Loads data from filepath.
If the connector type is kafka then no file_path is required, schema needs to be
@@ -117,7 +114,7 @@
"""
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath))
data_stream_reader = (
self._get_spark()
_get_spark()
.readStream.schema(self._schema)
.format(self._file_format)
.options(**self._load_args)
@@ -146,7 +143,7 @@ def _exists(self) -> bool:
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath))

try:
self._get_spark().readStream.schema(self._schema).load(
_get_spark().readStream.schema(self._schema).load(
load_path, self._file_format
)
except AnalysisException as exception:
10 changes: 5 additions & 5 deletions kedro-datasets/tests/spark/test_deltatable_dataset.py
@@ -86,13 +86,13 @@ def test_exists(self, tmp_path, sample_spark_df):
def test_exists_raises_error(self, mocker):
delta_ds = DeltaTableDataset(filepath="")
if SPARK_VERSION >= Version("3.4.0"):
mocker.patch.object(
delta_ds, "_get_spark", side_effect=AnalysisException("Other Exception")
mocker.patch(
"kedro_datasets.spark.deltatable_dataset._get_spark",
side_effect=AnalysisException("Other Exception"),
)
else:
mocker.patch.object(
delta_ds,
"_get_spark",
mocker.patch(
"kedro_datasets.spark.deltatable_dataset._get_spark",
side_effect=AnalysisException("Other Exception", []),
)
with pytest.raises(DatasetError, match="Other Exception"):
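
The test change above reflects that `_get_spark` is no longer a method on the dataset instance but a module-level function imported by `deltatable_dataset`, so tests patch the name where it is looked up rather than on the object. A hedged sketch of that pattern (pytest + pytest-mock; the test name, file path, and assertions are illustrative, not taken from the PR):

```python
# Minimal sketch of patching the module-level helper where it is looked up.
# Requires pytest, pytest-mock, pyspark, and delta-spark to be installed.
from unittest.mock import MagicMock

from kedro_datasets.spark.deltatable_dataset import DeltaTableDataset


def test_load_uses_module_level_get_spark(mocker):
    fake_spark = MagicMock(name="spark")
    # Patch the importing module's reference, not a method on the instance.
    mocker.patch(
        "kedro_datasets.spark.deltatable_dataset._get_spark",
        return_value=fake_spark,
    )
    forpath = mocker.patch(
        "kedro_datasets.spark.deltatable_dataset.DeltaTable.forPath"
    )

    DeltaTableDataset(filepath="/tmp/some/table")._load()

    # DeltaTable.forPath should receive the session returned by the patched helper.
    forpath.assert_called_once_with(fake_spark, "/tmp/some/table")
```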