diff --git a/docs/changelog/next_release/164.bugfix.rst b/docs/changelog/next_release/164.bugfix.rst new file mode 100644 index 000000000..e9c591100 --- /dev/null +++ b/docs/changelog/next_release/164.bugfix.rst @@ -0,0 +1 @@ +Fix ``Hive.check()`` behavior when Hive Metastore is not available. diff --git a/docs/changelog/next_release/164.improvement.rst b/docs/changelog/next_release/164.improvement.rst new file mode 100644 index 000000000..09799b8d2 --- /dev/null +++ b/docs/changelog/next_release/164.improvement.rst @@ -0,0 +1 @@ +Add check to all DB and FileDF connections that Spark session is alive. diff --git a/onetl/connection/db_connection/db_connection/connection.py b/onetl/connection/db_connection/db_connection/connection.py index 315f5b17c..731ef872d 100644 --- a/onetl/connection/db_connection/db_connection/connection.py +++ b/onetl/connection/db_connection/db_connection/connection.py @@ -17,7 +17,7 @@ from logging import getLogger from typing import TYPE_CHECKING -from pydantic import Field +from pydantic import Field, validator from onetl._util.spark import try_import_pyspark from onetl.base import BaseDBConnection @@ -48,6 +48,19 @@ def _forward_refs(cls) -> dict[str, type]: refs["SparkSession"] = SparkSession return refs + @validator("spark") + def _check_spark_session_alive(cls, spark): + # https://stackoverflow.com/a/36044685 + msg = "Spark session is stopped. Please recreate Spark session." + try: + if not spark._jsc.sc().isStopped(): + return spark + except Exception as e: + # None has no attribute "something" + raise ValueError(msg) from e + + raise ValueError(msg) + def _log_parameters(self): log.info("|%s| Using connection parameters:", self.__class__.__name__) parameters = self.dict(exclude_none=True, exclude={"spark"}) diff --git a/onetl/connection/db_connection/hive/connection.py b/onetl/connection/db_connection/hive/connection.py index 740d09b44..97cc034f5 100644 --- a/onetl/connection/db_connection/hive/connection.py +++ b/onetl/connection/db_connection/hive/connection.py @@ -146,7 +146,7 @@ class Hive(DBConnection): # TODO: remove in v1.0.0 slots = HiveSlots - _CHECK_QUERY: ClassVar[str] = "SELECT 1" + _CHECK_QUERY: ClassVar[str] = "SHOW DATABASES" @slot @classmethod @@ -207,7 +207,7 @@ def check(self): log_lines(log, self._CHECK_QUERY, level=logging.DEBUG) try: - self._execute_sql(self._CHECK_QUERY) + self._execute_sql(self._CHECK_QUERY).limit(1).collect() log.info("|%s| Connection is available.", self.__class__.__name__) except Exception as e: log.exception("|%s| Connection is unavailable", self.__class__.__name__) diff --git a/onetl/connection/db_connection/mongodb/connection.py b/onetl/connection/db_connection/mongodb/connection.py index 860f7b215..280596d5d 100644 --- a/onetl/connection/db_connection/mongodb/connection.py +++ b/onetl/connection/db_connection/mongodb/connection.py @@ -507,6 +507,7 @@ def write_df_to_target( ) if self._collection_exists(target): + # MongoDB connector does not support mode=ignore and mode=error if write_options.if_exists == MongoDBCollectionExistBehavior.ERROR: raise ValueError("Operation stopped due to MongoDB.WriteOptions(if_exists='error')") elif write_options.if_exists == MongoDBCollectionExistBehavior.IGNORE: diff --git a/onetl/connection/file_df_connection/spark_file_df_connection.py b/onetl/connection/file_df_connection/spark_file_df_connection.py index 7c1994182..10853078b 100644 --- a/onetl/connection/file_df_connection/spark_file_df_connection.py +++ b/onetl/connection/file_df_connection/spark_file_df_connection.py @@ -19,7 +19,7 @@ from logging import getLogger from typing import TYPE_CHECKING -from pydantic import Field +from pydantic import Field, validator from onetl._util.hadoop import get_hadoop_config from onetl._util.spark import try_import_pyspark @@ -182,6 +182,19 @@ def _forward_refs(cls) -> dict[str, type]: refs["SparkSession"] = SparkSession return refs + @validator("spark") + def _check_spark_session_alive(cls, spark): + # https://stackoverflow.com/a/36044685 + msg = "Spark session is stopped. Please recreate Spark session." + try: + if not spark._jsc.sc().isStopped(): + return spark + except Exception as e: + # None has no attribute "something" + raise ValueError(msg) from e + + raise ValueError(msg) + def _log_parameters(self): log.info("|%s| Using connection parameters:", self.__class__.__name__) parameters = self.dict(exclude_none=True, exclude={"spark"}) diff --git a/tests/fixtures/spark_mock.py b/tests/fixtures/spark_mock.py index b09e7764e..750e21eec 100644 --- a/tests/fixtures/spark_mock.py +++ b/tests/fixtures/spark_mock.py @@ -18,6 +18,23 @@ def spark_no_packages(): return spark +@pytest.fixture( + scope="function", + params=[pytest.param("mock-spark-stopped", marks=[pytest.mark.db_connection, pytest.mark.connection])], +) +def spark_stopped(): + import pyspark + from pyspark.sql import SparkSession + + spark = Mock(spec=SparkSession) + spark.sparkContext = Mock() + spark.sparkContext.appName = "abc" + spark.version = pyspark.__version__ + spark._sc = Mock() + spark._sc._gateway = Mock() + return spark + + @pytest.fixture( scope="function", params=[pytest.param("mock-spark", marks=[pytest.mark.db_connection, pytest.mark.connection])], @@ -29,7 +46,10 @@ def spark_mock(): spark = Mock(spec=SparkSession) spark.sparkContext = Mock() spark.sparkContext.appName = "abc" + spark.version = pyspark.__version__ spark._sc = Mock() spark._sc._gateway = Mock() - spark.version = pyspark.__version__ + spark._jsc = Mock() + spark._jsc.sc = Mock() + spark._jsc.sc().isStopped = Mock(return_value=False) return spark diff --git a/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py b/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py index 4451ef104..42b5582ae 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py @@ -33,6 +33,18 @@ def test_clickhouse_missing_package(spark_no_packages): ) +def test_clickhouse_spark_stopped(spark_stopped): + msg = "Spark session is stopped. Please recreate Spark session." + with pytest.raises(ValueError, match=msg): + Clickhouse( + host="some_host", + user="user", + database="database", + password="passwd", + spark=spark_stopped, + ) + + def test_clickhouse(spark_mock): conn = Clickhouse(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) diff --git a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py index 474c332d8..276a3c892 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py @@ -83,6 +83,18 @@ def test_greenplum_missing_package(spark_no_packages): ) +def test_greenplum_spark_stopped(spark_stopped): + msg = "Spark session is stopped. Please recreate Spark session." + with pytest.raises(ValueError, match=msg): + Greenplum( + host="some_host", + user="user", + database="database", + password="passwd", + spark=spark_stopped, + ) + + def test_greenplum(spark_mock): conn = Greenplum(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) diff --git a/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py b/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py index 6469b10c8..01cffd4a8 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py @@ -26,6 +26,12 @@ def test_hive_instance_url(spark_mock): assert hive.instance_url == "some-cluster" +def test_hive_spark_stopped(spark_stopped): + msg = "Spark session is stopped. Please recreate Spark session." + with pytest.raises(ValueError, match=msg): + Hive(cluster="some-cluster", spark=spark_stopped) + + def test_hive_get_known_clusters_hook(request, spark_mock): # no exception Hive(cluster="unknown", spark=spark_mock) @@ -60,8 +66,6 @@ def normalize_cluster_name(cluster: str) -> str: def test_hive_known_get_current_cluster_hook(request, spark_mock, mocker): - mocker.patch.object(Hive, "_execute_sql", return_value=None) - # no exception Hive(cluster="rnd-prod", spark=spark_mock).check() Hive(cluster="rnd-dwh", spark=spark_mock).check() diff --git a/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py b/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py index 1524d8ad7..404ca57fa 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py @@ -70,6 +70,16 @@ def test_kafka_missing_package(spark_no_packages): ) +def test_kafka_spark_stopped(spark_stopped): + msg = "Spark session is stopped. Please recreate Spark session." + with pytest.raises(ValueError, match=msg): + Kafka( + cluster="some_cluster", + addresses=["192.168.1.1"], + spark=spark_stopped, + ) + + @pytest.mark.parametrize( "option, value", [ diff --git a/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py index eb3f1db23..d53b4d614 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py @@ -79,6 +79,18 @@ def test_mongodb_missing_package(spark_no_packages): ) +def test_mongodb_spark_stopped(spark_stopped): + msg = "Spark session is stopped. Please recreate Spark session." + with pytest.raises(ValueError, match=msg): + MongoDB( + host="host", + user="user", + password="password", + database="database", + spark=spark_stopped, + ) + + def test_mongodb(spark_mock): conn = MongoDB( host="host", diff --git a/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py index e6cd8eb89..7b0328ca9 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py @@ -53,6 +53,18 @@ def test_mssql_missing_package(spark_no_packages): ) +def test_mssql_spark_stopped(spark_stopped): + msg = "Spark session is stopped. Please recreate Spark session." + with pytest.raises(ValueError, match=msg): + MSSQL( + host="some_host", + user="user", + database="database", + password="passwd", + spark=spark_stopped, + ) + + def test_mssql(spark_mock): conn = MSSQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) diff --git a/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py index 2a33c1523..ed730c418 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py @@ -33,6 +33,18 @@ def test_mysql_missing_package(spark_no_packages): ) +def test_mysql_spark_stopped(spark_stopped): + msg = "Spark session is stopped. Please recreate Spark session." + with pytest.raises(ValueError, match=msg): + MySQL( + host="some_host", + user="user", + database="database", + password="passwd", + spark=spark_stopped, + ) + + def test_mysql(spark_mock): conn = MySQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) diff --git a/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py b/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py index bddc14c0f..6a875b8f7 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py @@ -53,6 +53,18 @@ def test_oracle_missing_package(spark_no_packages): ) +def test_oracle_spark_stopped(spark_stopped): + msg = "Spark session is stopped. Please recreate Spark session." + with pytest.raises(ValueError, match=msg): + Oracle( + host="some_host", + user="user", + sid="sid", + password="passwd", + spark=spark_stopped, + ) + + def test_oracle(spark_mock): conn = Oracle(host="some_host", user="user", sid="sid", password="passwd", spark=spark_mock) diff --git a/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py b/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py index 228c94753..01f85eb08 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py @@ -33,6 +33,18 @@ def test_oracle_missing_package(spark_no_packages): ) +def test_postgres_spark_stopped(spark_stopped): + msg = "Spark session is stopped. Please recreate Spark session." + with pytest.raises(ValueError, match=msg): + Postgres( + host="some_host", + user="user", + database="database", + password="passwd", + spark=spark_stopped, + ) + + def test_postgres(spark_mock): conn = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) diff --git a/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py b/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py index 1daf14dc4..dd9ba525d 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py @@ -33,6 +33,18 @@ def test_teradata_missing_package(spark_no_packages): ) +def test_teradata_spark_stopped(spark_stopped): + msg = "Spark session is stopped. Please recreate Spark session." + with pytest.raises(ValueError, match=msg): + Teradata( + host="some_host", + user="user", + database="database", + password="passwd", + spark=spark_stopped, + ) + + def test_teradata(spark_mock): conn = Teradata(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) diff --git a/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py b/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py index 5e85c16f1..08ca6c1f4 100644 --- a/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py +++ b/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py @@ -5,14 +5,13 @@ import pytest from onetl.base import BaseFileDFConnection +from onetl.connection import SparkHDFS from onetl.hooks import hook pytestmark = [pytest.mark.hdfs, pytest.mark.file_df_connection, pytest.mark.connection] -def test_spark_hdfs_connection_with_cluster(spark_mock): - from onetl.connection import SparkHDFS - +def test_spark_hdfs_with_cluster(spark_mock): hdfs = SparkHDFS(cluster="rnd-dwh", spark=spark_mock) assert isinstance(hdfs, BaseFileDFConnection) assert hdfs.cluster == "rnd-dwh" @@ -21,9 +20,7 @@ def test_spark_hdfs_connection_with_cluster(spark_mock): assert hdfs.instance_url == "rnd-dwh" -def test_spark_hdfs_connection_with_cluster_and_host(spark_mock): - from onetl.connection import SparkHDFS - +def test_spark_hdfs_with_cluster_and_host(spark_mock): hdfs = SparkHDFS(cluster="rnd-dwh", host="some-host.domain.com", spark=spark_mock) assert isinstance(hdfs, BaseFileDFConnection) assert hdfs.cluster == "rnd-dwh" @@ -31,9 +28,7 @@ def test_spark_hdfs_connection_with_cluster_and_host(spark_mock): assert hdfs.instance_url == "rnd-dwh" -def test_spark_hdfs_connection_with_port(spark_mock): - from onetl.connection import SparkHDFS - +def test_spark_hdfs_with_port(spark_mock): hdfs = SparkHDFS(cluster="rnd-dwh", port=9020, spark=spark_mock) assert isinstance(hdfs, BaseFileDFConnection) assert hdfs.cluster == "rnd-dwh" @@ -41,9 +36,7 @@ def test_spark_hdfs_connection_with_port(spark_mock): assert hdfs.instance_url == "rnd-dwh" -def test_spark_hdfs_connection_without_cluster(spark_mock): - from onetl.connection import SparkHDFS - +def test_spark_hdfs_without_cluster(spark_mock): with pytest.raises(ValueError): SparkHDFS(spark=spark_mock) @@ -51,9 +44,13 @@ def test_spark_hdfs_connection_without_cluster(spark_mock): SparkHDFS(host="some", spark=spark_mock) -def test_spark_hdfs_get_known_clusters_hook(request, spark_mock): - from onetl.connection import SparkHDFS +def test_spark_hdfs_spark_stopped(spark_stopped): + msg = "Spark session is stopped. Please recreate Spark session." + with pytest.raises(ValueError, match=msg): + SparkHDFS(cluster="rnd-dwh", host="some-host.domain.com", spark=spark_stopped) + +def test_spark_hdfs_get_known_clusters_hook(request, spark_mock): @SparkHDFS.Slots.get_known_clusters.bind @hook def get_known_clusters() -> set[str]: @@ -71,8 +68,6 @@ def get_known_clusters() -> set[str]: def test_spark_hdfs_known_normalize_cluster_name_hook(request, spark_mock): - from onetl.connection import SparkHDFS - @SparkHDFS.Slots.normalize_cluster_name.bind @hook def normalize_cluster_name(cluster: str) -> str: @@ -86,8 +81,6 @@ def normalize_cluster_name(cluster: str) -> str: def test_spark_hdfs_get_cluster_namenodes_hook(request, spark_mock): - from onetl.connection import SparkHDFS - @SparkHDFS.Slots.get_cluster_namenodes.bind @hook def get_cluster_namenodes(cluster: str) -> set[str]: @@ -106,8 +99,6 @@ def get_cluster_namenodes(cluster: str) -> set[str]: def test_spark_hdfs_normalize_namenode_host_hook(request, spark_mock): - from onetl.connection import SparkHDFS - @SparkHDFS.Slots.normalize_namenode_host.bind @hook def normalize_namenode_host(host: str, cluster: str) -> str: @@ -124,8 +115,6 @@ def normalize_namenode_host(host: str, cluster: str) -> str: def test_spark_hdfs_get_ipc_port_hook(request, spark_mock): - from onetl.connection import SparkHDFS - @SparkHDFS.Slots.get_ipc_port.bind @hook def get_ipc_port(cluster: str) -> int | None: @@ -140,8 +129,6 @@ def get_ipc_port(cluster: str) -> int | None: def test_spark_hdfs_known_get_current(request, spark_mock): - from onetl.connection import SparkHDFS - # no hooks bound to SparkHDFS.Slots.get_current_cluster error_msg = re.escape( "SparkHDFS.get_current() can be used only if there are some hooks bound to SparkHDFS.Slots.get_current_cluster", diff --git a/tests/tests_unit/tests_file_df_connection_unit/test_spark_local_fs_unit.py b/tests/tests_unit/tests_file_df_connection_unit/test_spark_local_fs_unit.py index 8c8c8f377..e98c986cf 100644 --- a/tests/tests_unit/tests_file_df_connection_unit/test_spark_local_fs_unit.py +++ b/tests/tests_unit/tests_file_df_connection_unit/test_spark_local_fs_unit.py @@ -23,3 +23,9 @@ def test_spark_local_fs_spark_non_local(spark_mock, master): msg = re.escape("Currently supports only spark.master='local'") with pytest.raises(ValueError, match=msg): SparkLocalFS(spark=spark_mock) + + +def test_spark_local_fs_spark_stopped(spark_stopped): + msg = "Spark session is stopped. Please recreate Spark session." + with pytest.raises(ValueError, match=msg): + SparkLocalFS(spark=spark_stopped) diff --git a/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py b/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py index b146cebf6..99a20633c 100644 --- a/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py +++ b/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py @@ -32,7 +32,7 @@ def test_spark_s3_get_packages_spark_2_error(spark_version): @pytest.mark.parametrize("hadoop_version", ["2.7.3", "2.8.0", "2.10.1"]) -def test_spark_s3_connection_with_hadoop_2_error(spark_mock, hadoop_version): +def test_spark_s3_with_hadoop_2_error(spark_mock, hadoop_version): spark_mock._jvm = Mock() spark_mock._jvm.org.apache.hadoop.util.VersionInfo.getVersion = Mock(return_value=hadoop_version) @@ -47,7 +47,7 @@ def test_spark_s3_connection_with_hadoop_2_error(spark_mock, hadoop_version): ) -def test_spark_s3_connection_missing_package(spark_no_packages): +def test_spark_s3_missing_package(spark_no_packages): spark_no_packages._jvm = Mock() spark_no_packages._jvm.org.apache.hadoop.util.VersionInfo.getVersion = Mock(return_value="3.3.6") @@ -63,6 +63,19 @@ def test_spark_s3_connection_missing_package(spark_no_packages): ) +def test_spark_s3_spark_stopped(spark_stopped): + msg = "Spark session is stopped. Please recreate Spark session." + with pytest.raises(ValueError, match=msg): + SparkS3( + host="some_host", + access_key="access_key", + secret_key="some key", + session_token="some token", + bucket="bucket", + spark=spark_stopped, + ) + + @pytest.fixture() def spark_mock_hadoop_3(spark_mock): spark_mock._jvm = Mock() @@ -70,7 +83,7 @@ def spark_mock_hadoop_3(spark_mock): return spark_mock -def test_spark_s3_connection(spark_mock_hadoop_3): +def test_spark_s3(spark_mock_hadoop_3): s3 = SparkS3( host="some_host", access_key="access key", @@ -91,7 +104,7 @@ def test_spark_s3_connection(spark_mock_hadoop_3): assert "some key" not in repr(s3) -def test_spark_s3_connection_with_protocol_https(spark_mock_hadoop_3): +def test_spark_s3_with_protocol_https(spark_mock_hadoop_3): s3 = SparkS3( host="some_host", access_key="access_key", @@ -106,7 +119,7 @@ def test_spark_s3_connection_with_protocol_https(spark_mock_hadoop_3): assert s3.instance_url == "s3://some_host:443" -def test_spark_s3_connection_with_protocol_http(spark_mock_hadoop_3): +def test_spark_s3_with_protocol_http(spark_mock_hadoop_3): s3 = SparkS3( host="some_host", access_key="access_key", @@ -122,7 +135,7 @@ def test_spark_s3_connection_with_protocol_http(spark_mock_hadoop_3): @pytest.mark.parametrize("protocol", ["http", "https"]) -def test_spark_s3_connection_with_port(spark_mock_hadoop_3, protocol): +def test_spark_s3_with_port(spark_mock_hadoop_3, protocol): s3 = SparkS3( host="some_host", port=9000,