Skip to content

Commit

Permalink
[DOP-18571] Collect and log Spark metrics in various method calls
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Aug 8, 2024
1 parent 2ce8ec9 commit 018fbc2
Show file tree
Hide file tree
Showing 13 changed files with 209 additions and 89 deletions.
1 change: 1 addition & 0 deletions docs/changelog/next_release/303.feature.1.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Log estimated size of in-memory dataframe created by ``JDBC.fetch`` and ``JDBC.execute`` methods.
10 changes: 10 additions & 0 deletions docs/changelog/next_release/303.feature.2.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Collect Spark execution metrics in the following methods, and log them in DEBUG mode:
* ``DBWriter.run()``
* ``FileDFWriter.run()``
* ``Hive.sql()``
* ``Hive.execute()``

This is implemented using a custom ``SparkListener`` which wraps the entire method call, and
then reports the collected metrics. But these metrics may sometimes be missing due to Spark architecture,
so they are not a reliable source of information. That's why they are logged only in DEBUG mode, and
are not returned as the method call result.
17 changes: 16 additions & 1 deletion onetl/_util/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pydantic import SecretStr # type: ignore[no-redef, assignment]

if TYPE_CHECKING:
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.conf import RuntimeConfig


Expand Down Expand Up @@ -136,6 +136,21 @@ def get_spark_version(spark_session: SparkSession) -> Version:
return Version(spark_session.version)


def estimate_dataframe_size(spark_session: SparkSession, df: DataFrame) -> int:
    """
    Return an estimate (in bytes) of how much memory the DataFrame occupies.

    The value is produced by Spark's JVM-side
    `SizeEstimator <https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/util/SizeEstimator.html>`_,
    which is handed the Java object backing ``df``.

    This is a best-effort helper: any failure while reaching the estimator
    or while estimating results in ``0`` instead of an exception.
    """
    try:
        jvm = spark_session._jvm
        estimator = jvm.org.apache.spark.util.SizeEstimator  # type: ignore[union-attr]
        estimated_bytes = estimator.estimate(df._jdf)
    except Exception:
        # SizeEstimator relies on Java reflection, which may behave differently
        # across Java versions, or may be prohibited entirely.
        return 0

    return estimated_bytes


def get_executor_total_cores(spark_session: SparkSession, include_driver: bool = False) -> tuple[int | float, dict]:
"""
Calculate maximum number of cores which can be used by Spark on all executors.
Expand Down
3 changes: 2 additions & 1 deletion onetl/base/base_db_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

if TYPE_CHECKING:
from etl_entities.hwm import HWM
from pyspark.sql import DataFrame
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import StructField, StructType


Expand Down Expand Up @@ -106,6 +106,7 @@ class BaseDBConnection(BaseConnection):
Implements generic methods for reading and writing dataframe from/to database-like source
"""

spark: SparkSession
Dialect = BaseDBDialect

@property
Expand Down
4 changes: 3 additions & 1 deletion onetl/base/base_file_df_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from onetl.base.pure_path_protocol import PurePathProtocol

if TYPE_CHECKING:
from pyspark.sql import DataFrame, DataFrameReader, DataFrameWriter
from pyspark.sql import DataFrame, DataFrameReader, DataFrameWriter, SparkSession
from pyspark.sql.types import StructType


Expand Down Expand Up @@ -72,6 +72,8 @@ class BaseFileDFConnection(BaseConnection):
.. versionadded:: 0.9.0
"""

spark: SparkSession

@abstractmethod
def check_if_format_supported(
self,
Expand Down
49 changes: 45 additions & 4 deletions onetl/connection/db_connection/hive/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
except (ImportError, AttributeError):
from pydantic import validator # type: ignore[no-redef, assignment]

from onetl._metrics.recorder import SparkMetricsRecorder
from onetl._util.spark import inject_spark_param
from onetl._util.sql import clear_statement
from onetl.connection.db_connection.db_connection import DBConnection
Expand Down Expand Up @@ -210,8 +211,29 @@ def sql(
log.info("|%s| Executing SQL query:", self.__class__.__name__)
log_lines(log, query)

df = self._execute_sql(query)
log.info("|Spark| DataFrame successfully created from SQL statement")
with SparkMetricsRecorder(self.spark) as recorder:
try:
df = self._execute_sql(query)
except Exception:
log.error("|%s| Query failed", self.__class__.__name__)

metrics = recorder.metrics()
if not metrics.is_empty and log.isEnabledFor(logging.DEBUG):
# as SparkListener results are not guaranteed to be received in time,
# some metrics may be missing. To avoid confusion, log only in debug, and with a notice
log.info("|%s| Recorded metrics (some values may be missing!):", self.__class__.__name__)
log_lines(log, str(metrics), level=logging.DEBUG)
raise

log.info("|Spark| DataFrame successfully created from SQL statement")

metrics = recorder.metrics()
if not metrics.is_empty and log.isEnabledFor(logging.DEBUG):
# as SparkListener results are not guaranteed to be received in time,
# some metrics may be missing. To avoid confusion, log only in debug, and with a notice
log.info("|%s| Recorded metrics (some values may be missing!):", self.__class__.__name__)
log_lines(log, str(metrics), level=logging.DEBUG)

return df

@slot
Expand All @@ -236,8 +258,27 @@ def execute(
log.info("|%s| Executing statement:", self.__class__.__name__)
log_lines(log, statement)

self._execute_sql(statement).collect()
log.info("|%s| Call succeeded", self.__class__.__name__)
with SparkMetricsRecorder(self.spark) as recorder:
try:
self._execute_sql(statement).collect()
except Exception:
log.error("|%s| Execution failed", self.__class__.__name__)
metrics = recorder.metrics()
if not metrics.is_empty and log.isEnabledFor(logging.DEBUG):
# as SparkListener results are not guaranteed to be received in time,
# some metrics may be missing. To avoid confusion, log only in debug, and with a notice
log.info("|%s| Recorded metrics (some values may be missing!):", self.__class__.__name__)
log_lines(log, str(metrics), level=logging.DEBUG)
raise

log.info("|%s| Execution succeeded", self.__class__.__name__)

metrics = recorder.metrics()
if not metrics.is_empty and log.isEnabledFor(logging.DEBUG):
# as SparkListener results are not guaranteed to be received in time,
# some metrics may be missing. To avoid confusion, log only in debug, and with a notice
log.info("|%s| Recorded metrics (some values may be missing!):", self.__class__.__name__)
log_lines(log, str(metrics), level=logging.DEBUG)

@slot
def write_df_to_target(
Expand Down
8 changes: 6 additions & 2 deletions onetl/connection/db_connection/jdbc_connection/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,13 @@ def sql(
log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__)
log_lines(log, query)

df = self._query_on_executor(query, self.SQLOptions.parse(options))
try:
df = self._query_on_executor(query, self.SQLOptions.parse(options))
except Exception:
log.error("|%s| Query failed!", self.__class__.__name__)
raise

log.info("|Spark| DataFrame successfully created from SQL statement ")
log.info("|Spark| DataFrame successfully created from SQL statement")
return df

@slot
Expand Down
69 changes: 42 additions & 27 deletions onetl/connection/db_connection/jdbc_mixin/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
from enum import Enum, auto
from typing import TYPE_CHECKING, Callable, ClassVar, Optional, TypeVar

from onetl.impl.generic_options import GenericOptions

try:
from pydantic.v1 import Field, PrivateAttr, SecretStr, validator
except (ImportError, AttributeError):
from pydantic import Field, PrivateAttr, SecretStr, validator # type: ignore[no-redef, assignment]

from onetl._metrics.command import SparkCommandMetrics
from onetl._util.java import get_java_gateway, try_import_java_class
from onetl._util.spark import get_spark_version, stringify
from onetl._util.spark import estimate_dataframe_size, get_spark_version, stringify
from onetl._util.sql import clear_statement
from onetl._util.version import Version
from onetl.connection.db_connection.jdbc_mixin.options import (
Expand All @@ -29,7 +28,7 @@
)
from onetl.exception import MISSING_JVM_CLASS_MSG
from onetl.hooks import slot, support_hooks
from onetl.impl import FrozenModel
from onetl.impl import FrozenModel, GenericOptions
from onetl.log import log_lines

if TYPE_CHECKING:
Expand Down Expand Up @@ -204,20 +203,27 @@ def fetch(
log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__)
log_lines(log, query)

df = self._query_on_driver(
query,
(
self.FetchOptions.parse(options.dict()) # type: ignore
if isinstance(options, JDBCMixinOptions)
else self.FetchOptions.parse(options)
),
call_options = (
self.FetchOptions.parse(options.dict()) # type: ignore
if isinstance(options, JDBCMixinOptions)
else self.FetchOptions.parse(options)
)

log.info(
"|%s| Query succeeded, resulting in-memory dataframe contains %d rows",
self.__class__.__name__,
df.count(),
)
try:
df = self._query_on_driver(query, call_options)
except Exception:
log.error("|%s| Query failed!", self.__class__.__name__)
raise

log.info("|%s| Query succeeded, created in-memory dataframe.", self.__class__.__name__)

# as we don't actually use Spark for this method, SparkMetricsRecorder is useless.
# Just create metrics by hand, and fill them up using information based on dataframe content.
metrics = SparkCommandMetrics()
metrics.input.read_rows = df.count()
metrics.driver.in_memory_bytes = estimate_dataframe_size(self.spark, df)
log.info("|%s| Recorded metrics:", self.__class__.__name__)
log_lines(log, str(metrics))
return df

@slot
Expand Down Expand Up @@ -273,17 +279,26 @@ def execute(
if isinstance(options, JDBCMixinOptions)
else self.ExecuteOptions.parse(options)
)
df = self._call_on_driver(statement, call_options)

if df is not None:
rows_count = df.count()
log.info(
"|%s| Execution succeeded, resulting in-memory dataframe contains %d rows",
self.__class__.__name__,
rows_count,
)
else:
log.info("|%s| Execution succeeded, nothing returned", self.__class__.__name__)

try:
df = self._call_on_driver(statement, call_options)
except Exception:
log.error("|%s| Execution failed!", self.__class__.__name__)
raise

if not df:
log.info("|%s| Execution succeeded, nothing returned.", self.__class__.__name__)
return None

log.info("|%s| Execution succeeded, created in-memory dataframe.", self.__class__.__name__)
# as we don't actually use Spark for this method, SparkMetricsRecorder is useless.
# Just create metrics by hand, and fill them up using information based on dataframe content.
metrics = SparkCommandMetrics()
metrics.input.read_rows = df.count()
metrics.driver.in_memory_bytes = estimate_dataframe_size(self.spark, df)

log.info("|%s| Recorded metrics:", self.__class__.__name__)
log_lines(log, str(metrics))
return df

@validator("spark")
Expand Down
39 changes: 9 additions & 30 deletions onetl/connection/db_connection/oracle/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@
from etl_entities.instance import Host

from onetl._util.classproperty import classproperty
from onetl._util.sql import clear_statement
from onetl._util.version import Version
from onetl.connection.db_connection.jdbc_connection import JDBCConnection
from onetl.connection.db_connection.jdbc_connection.options import JDBCReadOptions
from onetl.connection.db_connection.jdbc_mixin.options import (
JDBCExecuteOptions,
JDBCFetchOptions,
JDBCOptions,
)
from onetl.connection.db_connection.oracle.dialect import OracleDialect
from onetl.connection.db_connection.oracle.options import (
Expand All @@ -43,8 +41,6 @@
from onetl.log import BASE_LOG_INDENT, log_lines

# do not import PySpark here, as we allow user to use `Oracle.get_packages()` for creating Spark session


if TYPE_CHECKING:
from pyspark.sql import DataFrame

Expand Down Expand Up @@ -290,32 +286,6 @@ def get_min_max_values(
max_value = int(max_value)
return min_value, max_value

@slot
def execute(
self,
statement: str,
options: JDBCOptions | JDBCExecuteOptions | dict | None = None, # noqa: WPS437
) -> DataFrame | None:
statement = clear_statement(statement)

log.info("|%s| Executing statement (on driver):", self.__class__.__name__)
log_lines(log, statement)

call_options = self.ExecuteOptions.parse(options)
df = self._call_on_driver(statement, call_options)
self._handle_compile_errors(statement.strip(), call_options)

if df is not None:
rows_count = df.count()
log.info(
"|%s| Execution succeeded, resulting in-memory dataframe contains %d rows",
self.__class__.__name__,
rows_count,
)
else:
log.info("|%s| Execution succeeded, nothing returned", self.__class__.__name__)
return df

@root_validator
def _only_one_of_sid_or_service_name(cls, values):
sid = values.get("sid")
Expand All @@ -329,6 +299,15 @@ def _only_one_of_sid_or_service_name(cls, values):

return values

def _call_on_driver(
    self,
    query: str,
    options: JDBCExecuteOptions,
) -> DataFrame | None:
    """
    Run the statement through the parent implementation, then check for compile errors.

    After the parent call returns, ``_handle_compile_errors`` is invoked with
    the whitespace-stripped statement and the same options.
    NOTE(review): presumably this surfaces Oracle compilation errors for
    ``CREATE ...`` statements - confirm against ``_handle_compile_errors``.

    Returns whatever the parent ``_call_on_driver`` returned (a DataFrame or ``None``).
    """
    df = super()._call_on_driver(query, options)
    # The check runs only after the parent call has completed without raising.
    self._handle_compile_errors(query.strip(), options)
    return df

def _parse_create_statement(self, statement: str) -> tuple[str, str, str] | None:
"""
Parses ``CREATE ... type_name [schema.]object_name ...`` statement
Expand Down
Loading

0 comments on commit 018fbc2

Please sign in to comment.