Skip to content

Commit

Permalink
[DOP-6758] Fix Hive.check() behavior when Hive Metastore is not avail…
Browse files Browse the repository at this point in the history
…able
  • Loading branch information
dolfinus committed Oct 6, 2023
1 parent 32c37ed commit fe52691
Show file tree
Hide file tree
Showing 20 changed files with 205 additions and 37 deletions.
1 change: 1 addition & 0 deletions docs/changelog/next_release/164.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix ``Hive.check()`` behavior when Hive Metastore is not available.
1 change: 1 addition & 0 deletions docs/changelog/next_release/164.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add check to all DB and FileDF connections that Spark session is alive.
15 changes: 14 additions & 1 deletion onetl/connection/db_connection/db_connection/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from logging import getLogger
from typing import TYPE_CHECKING

from pydantic import Field
from pydantic import Field, validator

from onetl._util.spark import try_import_pyspark
from onetl.base import BaseDBConnection
Expand Down Expand Up @@ -48,6 +48,19 @@ def _forward_refs(cls) -> dict[str, type]:
refs["SparkSession"] = SparkSession
return refs

@validator("spark")
def _check_spark_session_alive(cls, spark):
# https://stackoverflow.com/a/36044685
msg = "Spark session is stopped. Please recreate Spark session."
try:
if not spark._jsc.sc().isStopped():
return spark
except Exception as e:
# None has no attribute "something"
raise ValueError(msg) from e

raise ValueError(msg)

Check warning on line 62 in onetl/connection/db_connection/db_connection/connection.py

View check run for this annotation

Codecov / codecov/patch

onetl/connection/db_connection/db_connection/connection.py#L62

Added line #L62 was not covered by tests

def _log_parameters(self):
log.info("|%s| Using connection parameters:", self.__class__.__name__)
parameters = self.dict(exclude_none=True, exclude={"spark"})
Expand Down
4 changes: 2 additions & 2 deletions onetl/connection/db_connection/hive/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class Hive(DBConnection):
# TODO: remove in v1.0.0
slots = HiveSlots

_CHECK_QUERY: ClassVar[str] = "SELECT 1"
_CHECK_QUERY: ClassVar[str] = "SHOW DATABASES"

@slot
@classmethod
Expand Down Expand Up @@ -207,7 +207,7 @@ def check(self):
log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)

try:
self._execute_sql(self._CHECK_QUERY)
self._execute_sql(self._CHECK_QUERY).limit(1).collect()
log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
Expand Down
1 change: 1 addition & 0 deletions onetl/connection/db_connection/mongodb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ def write_df_to_target(
)

if self._collection_exists(target):
# MongoDB connector does not support mode=ignore and mode=error
if write_options.if_exists == MongoDBCollectionExistBehavior.ERROR:
raise ValueError("Operation stopped due to MongoDB.WriteOptions(if_exists='error')")
elif write_options.if_exists == MongoDBCollectionExistBehavior.IGNORE:
Expand Down
15 changes: 14 additions & 1 deletion onetl/connection/file_df_connection/spark_file_df_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from logging import getLogger
from typing import TYPE_CHECKING

from pydantic import Field
from pydantic import Field, validator

from onetl._util.hadoop import get_hadoop_config
from onetl._util.spark import try_import_pyspark
Expand Down Expand Up @@ -182,6 +182,19 @@ def _forward_refs(cls) -> dict[str, type]:
refs["SparkSession"] = SparkSession
return refs

@validator("spark")
def _check_spark_session_alive(cls, spark):
# https://stackoverflow.com/a/36044685
msg = "Spark session is stopped. Please recreate Spark session."
try:
if not spark._jsc.sc().isStopped():
return spark
except Exception as e:
# None has no attribute "something"
raise ValueError(msg) from e

raise ValueError(msg)

Check warning on line 196 in onetl/connection/file_df_connection/spark_file_df_connection.py

View check run for this annotation

Codecov / codecov/patch

onetl/connection/file_df_connection/spark_file_df_connection.py#L196

Added line #L196 was not covered by tests

def _log_parameters(self):
log.info("|%s| Using connection parameters:", self.__class__.__name__)
parameters = self.dict(exclude_none=True, exclude={"spark"})
Expand Down
25 changes: 24 additions & 1 deletion tests/fixtures/spark_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,23 @@
import pytest


@pytest.fixture(
scope="function",
params=[pytest.param("mock-spark-stopped", marks=[pytest.mark.db_connection, pytest.mark.connection])],
)
def spark_stopped():
import pyspark
from pyspark.sql import SparkSession

spark = Mock(spec=SparkSession)
spark.sparkContext = Mock()
spark.sparkContext.appName = "abc"
spark.version = pyspark.__version__
spark._sc = Mock()
spark._sc._gateway = Mock()
return spark


@pytest.fixture(
scope="function",
params=[pytest.param("mock-spark-no-packages", marks=[pytest.mark.db_connection, pytest.mark.connection])],
Expand All @@ -15,6 +32,9 @@ def spark_no_packages():
spark.sparkContext = Mock()
spark.sparkContext.appName = "abc"
spark.version = pyspark.__version__
spark._jsc = Mock()
spark._jsc.sc = Mock()
spark._jsc.sc().isStopped = Mock(return_value=False)
return spark


Expand All @@ -29,7 +49,10 @@ def spark_mock():
spark = Mock(spec=SparkSession)
spark.sparkContext = Mock()
spark.sparkContext.appName = "abc"
spark.version = pyspark.__version__
spark._sc = Mock()
spark._sc._gateway = Mock()
spark.version = pyspark.__version__
spark._jsc = Mock()
spark._jsc.sc = Mock()
spark._jsc.sc().isStopped = Mock(return_value=False)
return spark
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ def test_clickhouse_missing_package(spark_no_packages):
)


def test_clickhouse_spark_stopped(spark_stopped):
msg = "Spark session is stopped. Please recreate Spark session."
with pytest.raises(ValueError, match=msg):
Clickhouse(
host="some_host",
user="user",
database="database",
password="passwd",
spark=spark_stopped,
)


def test_clickhouse(spark_mock):
conn = Clickhouse(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)

Expand Down
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,18 @@ def test_greenplum_missing_package(spark_no_packages):
)


def test_greenplum_spark_stopped(spark_stopped):
msg = "Spark session is stopped. Please recreate Spark session."
with pytest.raises(ValueError, match=msg):
Greenplum(
host="some_host",
user="user",
database="database",
password="passwd",
spark=spark_stopped,
)


def test_greenplum(spark_mock):
conn = Greenplum(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)

Expand Down
8 changes: 6 additions & 2 deletions tests/tests_unit/tests_db_connection_unit/test_hive_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def test_hive_instance_url(spark_mock):
assert hive.instance_url == "some-cluster"


def test_hive_spark_stopped(spark_stopped):
msg = "Spark session is stopped. Please recreate Spark session."
with pytest.raises(ValueError, match=msg):
Hive(cluster="some-cluster", spark=spark_stopped)


def test_hive_get_known_clusters_hook(request, spark_mock):
# no exception
Hive(cluster="unknown", spark=spark_mock)
Expand Down Expand Up @@ -60,8 +66,6 @@ def normalize_cluster_name(cluster: str) -> str:


def test_hive_known_get_current_cluster_hook(request, spark_mock, mocker):
mocker.patch.object(Hive, "_execute_sql", return_value=None)

# no exception
Hive(cluster="rnd-prod", spark=spark_mock).check()
Hive(cluster="rnd-dwh", spark=spark_mock).check()
Expand Down
10 changes: 10 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ def test_kafka_missing_package(spark_no_packages):
)


def test_kafka_spark_stopped(spark_stopped):
msg = "Spark session is stopped. Please recreate Spark session."
with pytest.raises(ValueError, match=msg):
Kafka(
cluster="some_cluster",
addresses=["192.168.1.1"],
spark=spark_stopped,
)


@pytest.mark.parametrize(
"option, value",
[
Expand Down
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ def test_mongodb_missing_package(spark_no_packages):
)


def test_mongodb_spark_stopped(spark_stopped):
msg = "Spark session is stopped. Please recreate Spark session."
with pytest.raises(ValueError, match=msg):
MongoDB(
host="host",
user="user",
password="password",
database="database",
spark=spark_stopped,
)


def test_mongodb(spark_mock):
conn = MongoDB(
host="host",
Expand Down
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ def test_mssql_missing_package(spark_no_packages):
)


def test_mssql_spark_stopped(spark_stopped):
msg = "Spark session is stopped. Please recreate Spark session."
with pytest.raises(ValueError, match=msg):
MSSQL(
host="some_host",
user="user",
database="database",
password="passwd",
spark=spark_stopped,
)


def test_mssql(spark_mock):
conn = MSSQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)

Expand Down
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ def test_mysql_missing_package(spark_no_packages):
)


def test_mysql_spark_stopped(spark_stopped):
msg = "Spark session is stopped. Please recreate Spark session."
with pytest.raises(ValueError, match=msg):
MySQL(
host="some_host",
user="user",
database="database",
password="passwd",
spark=spark_stopped,
)


def test_mysql(spark_mock):
conn = MySQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)

Expand Down
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ def test_oracle_missing_package(spark_no_packages):
)


def test_oracle_spark_stopped(spark_stopped):
msg = "Spark session is stopped. Please recreate Spark session."
with pytest.raises(ValueError, match=msg):
Oracle(
host="some_host",
user="user",
sid="sid",
password="passwd",
spark=spark_stopped,
)


def test_oracle(spark_mock):
conn = Oracle(host="some_host", user="user", sid="sid", password="passwd", spark=spark_mock)

Expand Down
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ def test_oracle_missing_package(spark_no_packages):
)


def test_postgres_spark_stopped(spark_stopped):
msg = "Spark session is stopped. Please recreate Spark session."
with pytest.raises(ValueError, match=msg):
Postgres(
host="some_host",
user="user",
database="database",
password="passwd",
spark=spark_stopped,
)


def test_postgres(spark_mock):
conn = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)

Expand Down
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ def test_teradata_missing_package(spark_no_packages):
)


def test_teradata_spark_stopped(spark_stopped):
msg = "Spark session is stopped. Please recreate Spark session."
with pytest.raises(ValueError, match=msg):
Teradata(
host="some_host",
user="user",
database="database",
password="passwd",
spark=spark_stopped,
)


def test_teradata(spark_mock):
conn = Teradata(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)

Expand Down
Loading

0 comments on commit fe52691

Please sign in to comment.