# Reviewer notes for this patch. They are placed before the first "diff --git"
# header; NOTE(review): patch tools generally skip such leading text - confirm
# for the tool used to apply this file.
#
# What the hunks below change:
#   * docs/changelog/next_release/164.bugfix.rst (new file): changelog entry for
#     the Hive.check() fix when Hive Metastore is not available.
#   * docs/changelog/next_release/164.improvement.rst (new file): changelog entry
#     for the new "Spark session is alive" check on DB and FileDF connections.
#   * onetl/connection/db_connection/db_connection/connection.py: imports
#     pydantic's `validator` and adds `_check_spark_session_alive`, a validator
#     on the `spark` field. It calls spark.sql("SELECT 1"); any Exception is
#     re-raised as ValueError("Spark session is stopped. Please recreate Spark
#     session.") chained to the original error; otherwise the same `spark`
#     object is returned.
#   * onetl/connection/db_connection/hive/connection.py: Hive._CHECK_QUERY
#     changes from "SELECT 1" to "SHOW DATABASES".
#     NOTE(review): presumably this is the Hive.check() fix named in the bugfix
#     changelog entry - confirm.
#   * onetl/connection/db_connection/mongodb/connection.py: comment-only change
#     inside write_df_to_target, placed above the existing if_exists ERROR/IGNORE
#     branches.
#   * onetl/connection/file_df_connection/spark_file_df_connection.py: the same
#     `validator` import and an identical `_check_spark_session_alive` validator.
#
# NOTE(review): both validators catch every Exception raised by spark.sql(), so
# a failure unrelated to a stopped session would also be reported as "Spark
# session is stopped" - confirm this is intended.
# NOTE(review): the validator body is duplicated verbatim in two classes; a
# shared helper may be possible - verify against the package's import layout.
diff --git a/docs/changelog/next_release/164.bugfix.rst b/docs/changelog/next_release/164.bugfix.rst new file mode 100644 index 000000000..e9c591100 --- /dev/null +++ b/docs/changelog/next_release/164.bugfix.rst @@ -0,0 +1 @@ +Fix ``Hive.check()`` behavior when Hive Metastore is not available. diff --git a/docs/changelog/next_release/164.improvement.rst b/docs/changelog/next_release/164.improvement.rst new file mode 100644 index 000000000..09799b8d2 --- /dev/null +++ b/docs/changelog/next_release/164.improvement.rst @@ -0,0 +1 @@ +Add check to all DB and FileDF connections that Spark session is alive. diff --git a/onetl/connection/db_connection/db_connection/connection.py b/onetl/connection/db_connection/db_connection/connection.py index 315f5b17c..138f8f22a 100644 --- a/onetl/connection/db_connection/db_connection/connection.py +++ b/onetl/connection/db_connection/db_connection/connection.py @@ -17,7 +17,7 @@ from logging import getLogger from typing import TYPE_CHECKING -from pydantic import Field +from pydantic import Field, validator from onetl._util.spark import try_import_pyspark from onetl.base import BaseDBConnection @@ -48,6 +48,16 @@ def _forward_refs(cls) -> dict[str, type]: refs["SparkSession"] = SparkSession return refs + @validator("spark") + def _check_spark_session_alive(cls, spark): + try: + spark.sql("SELECT 1") + except Exception as e: + msg = "Spark session is stopped. Please recreate Spark session." 
+ raise ValueError(msg) from e + + return spark + def _log_parameters(self): log.info("|%s| Using connection parameters:", self.__class__.__name__) parameters = self.dict(exclude_none=True, exclude={"spark"}) diff --git a/onetl/connection/db_connection/hive/connection.py b/onetl/connection/db_connection/hive/connection.py index 740d09b44..37d384cd1 100644 --- a/onetl/connection/db_connection/hive/connection.py +++ b/onetl/connection/db_connection/hive/connection.py @@ -146,7 +146,7 @@ class Hive(DBConnection): # TODO: remove in v1.0.0 slots = HiveSlots - _CHECK_QUERY: ClassVar[str] = "SELECT 1" + _CHECK_QUERY: ClassVar[str] = "SHOW DATABASES" @slot @classmethod diff --git a/onetl/connection/db_connection/mongodb/connection.py b/onetl/connection/db_connection/mongodb/connection.py index 860f7b215..280596d5d 100644 --- a/onetl/connection/db_connection/mongodb/connection.py +++ b/onetl/connection/db_connection/mongodb/connection.py @@ -507,6 +507,7 @@ def write_df_to_target( ) if self._collection_exists(target): + # MongoDB connector does not support mode=ignore and mode=error if write_options.if_exists == MongoDBCollectionExistBehavior.ERROR: raise ValueError("Operation stopped due to MongoDB.WriteOptions(if_exists='error')") elif write_options.if_exists == MongoDBCollectionExistBehavior.IGNORE: diff --git a/onetl/connection/file_df_connection/spark_file_df_connection.py b/onetl/connection/file_df_connection/spark_file_df_connection.py index 7c1994182..e56a6dfce 100644 --- a/onetl/connection/file_df_connection/spark_file_df_connection.py +++ b/onetl/connection/file_df_connection/spark_file_df_connection.py @@ -19,7 +19,7 @@ from logging import getLogger from typing import TYPE_CHECKING -from pydantic import Field +from pydantic import Field, validator from onetl._util.hadoop import get_hadoop_config from onetl._util.spark import try_import_pyspark @@ -182,6 +182,16 @@ def _forward_refs(cls) -> dict[str, type]: refs["SparkSession"] = SparkSession return refs + 
@validator("spark") + def _check_spark_session_alive(cls, spark): + try: + spark.sql("SELECT 1") + except Exception as e: + msg = "Spark session is stopped. Please recreate Spark session." + raise ValueError(msg) from e + + return spark + def _log_parameters(self): log.info("|%s| Using connection parameters:", self.__class__.__name__) parameters = self.dict(exclude_none=True, exclude={"spark"})