Skip to content

Commit

Permalink
[DOP-18570] Implement SparkMetricsRecorder
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Aug 7, 2024
1 parent c4a9cb8 commit 5c068cd
Show file tree
Hide file tree
Showing 32 changed files with 2,343 additions and 9 deletions.
34 changes: 34 additions & 0 deletions onetl/_util/java.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

from typing import TYPE_CHECKING

from onetl._util.spark import get_spark_version
from onetl._util.version import Version

if TYPE_CHECKING:
from py4j.java_gateway import JavaGateway
from pyspark.sql import SparkSession
Expand All @@ -24,3 +27,34 @@ def try_import_java_class(spark_session: SparkSession, name: str):
klass = getattr(gateway.jvm, name)
gateway.help(klass, display=False)
return klass


def start_callback_server(spark_session: SparkSession):
"""
Start Py4J callback server. Important to receive Java events on Python side,
e.g. in Spark Listener implementations.
"""
gateway = get_java_gateway(spark_session)
if get_spark_version(spark_session) >= Version("2.4"):
from pyspark.java_gateway import ensure_callback_server_started

ensure_callback_server_started(gateway)
return

# PySpark 2.3
if "_callback_server" not in gateway.__dict__ or gateway._callback_server is None:
from py4j.java_gateway import JavaObject

gateway.callback_server_parameters.eager_load = True
gateway.callback_server_parameters.daemonize = True
gateway.callback_server_parameters.daemonize_connections = True
gateway.callback_server_parameters.port = 0
gateway.start_callback_server(gateway.callback_server_parameters)
cbport = gateway._callback_server.server_socket.getsockname()[1]
gateway._callback_server.port = cbport
# gateway with real port
gateway._python_proxy_port = gateway._callback_server.port
# get the GatewayServer object in JVM by ID
java_gateway = JavaObject("GATEWAY_SERVER", gateway._gateway_client)
# update the port of CallbackClient with real port
java_gateway.resetCallbackClient(java_gateway.getCallbackClient().getAddress(), gateway._python_proxy_port)
7 changes: 7 additions & 0 deletions onetl/_util/scala.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,10 @@ def get_default_scala_version(spark_version: Version) -> Version:
if spark_version.major < 3:
return Version("2.11")
return Version("2.12")


def scala_seq_to_python_list(seq) -> list:
result = []
for i in range(seq.size()):
result.append(seq.apply(i))
return result
12 changes: 11 additions & 1 deletion onetl/_util/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pydantic import SecretStr # type: ignore[no-redef, assignment]

if TYPE_CHECKING:
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.conf import RuntimeConfig


Expand Down Expand Up @@ -136,6 +136,16 @@ def get_spark_version(spark_session: SparkSession) -> Version:
return Version(spark_session.version)


def estimate_dataframe_size(spark_session: SparkSession, df: DataFrame) -> int:
"""
Estimate in-memory DataFrame size in bytes.
Using Spark's `SizeEstimator <https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/util/SizeEstimator.html>`_.
"""
size_estimator = spark_session._jvm.org.apache.spark.util.SizeEstimator # type: ignore[union-attr]
return size_estimator.estimate(df._jdf)


def get_executor_total_cores(spark_session: SparkSession, include_driver: bool = False) -> tuple[int | float, dict]:
"""
Calculate maximum number of cores which can be used by Spark on all executors.
Expand Down
16 changes: 16 additions & 0 deletions onetl/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
# SPDX-License-Identifier: Apache-2.0
from onetl.metrics.command import SparkCommandMetrics
from onetl.metrics.driver import SparkDriverMetrics
from onetl.metrics.executor import SparkExecutorMetrics
from onetl.metrics.input import SparkInputMetrics
from onetl.metrics.output import SparkOutputMetrics
from onetl.metrics.recorder import SparkMetricsRecorder

__all__ = [
"SparkCommandMetrics",
"SparkMetricsRecorder",
"SparkExecutorMetrics",
"SparkInputMetrics",
"SparkOutputMetrics",
]
141 changes: 141 additions & 0 deletions onetl/metrics/_extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import re
from datetime import timedelta
from typing import Any

try:
from pydantic.v1 import ByteSize
except (ImportError, AttributeError):
from pydantic import ByteSize # type: ignore[no-redef, assignment]

from onetl.metrics._listener.execution import (
SparkListenerExecution,
SparkSQLMetricNames,
)
from onetl.metrics.command import SparkCommandMetrics
from onetl.metrics.driver import SparkDriverMetrics
from onetl.metrics.executor import SparkExecutorMetrics
from onetl.metrics.input import SparkInputMetrics
from onetl.metrics.output import SparkOutputMetrics

NON_DIGIT = re.compile(r"[^\d.]")


def _get_int(data: dict[SparkSQLMetricNames, list[str]], key: Any) -> int | None:
if key not in data:
return None

items = data[key]
if not items:
return None

return int(items[0])


def _get_bytes(data: dict[SparkSQLMetricNames, list[str]], key: Any) -> int | None:
if key not in data:
return None

items = data[key]
if not items:
return None

return int(ByteSize.validate(items[0]))


def _get_time(data: dict[SparkSQLMetricNames, list[str]], key: Any) -> timedelta | None: # noqa: Found
if key not in data:
return None

items = data[key]
if not items:
return None

str_value = items[0]
digits = NON_DIGIT.sub("", str_value)
# reverse of msDurationToString:
# https://github.com/apache/spark/blob/v3.5.1/core/src/main/scala/org/apache/spark/util/Utils.scala#L1243-L1257
if str_value.endswith(" ms"):
return timedelta(milliseconds=float(digits))
if str_value.endswith(" s"):
return timedelta(seconds=float(digits))
if str_value.endswith(" m"):
return timedelta(minutes=float(digits))
return timedelta(hours=float(digits))


def extract_metrics_from_execution(execution: SparkListenerExecution) -> SparkCommandMetrics:
input_read_bytes: int = 0
input_read_rows: int = 0
output_bytes: int = 0
output_rows: int = 0

run_time_milliseconds: int = 0
cpu_time_nanoseconds: int = 0
peak_memory_bytes: int = 0
memory_spilled_bytes: int = 0
disk_spilled_bytes: int = 0
result_size_bytes: int = 0

# some metrics are per-stage, and have to be summed, others are per-execution
for job in execution.jobs:
for stage in job.stages:
input_read_bytes += stage.metrics.input_metrics.bytes_read
input_read_rows += stage.metrics.input_metrics.records_read
output_bytes += stage.metrics.output_metrics.bytes_written
output_rows += stage.metrics.output_metrics.records_written

run_time_milliseconds += stage.metrics.executor_run_time_milliseconds
cpu_time_nanoseconds += stage.metrics.executor_cpu_time_nanoseconds
peak_memory_bytes = max(peak_memory_bytes, stage.metrics.peak_execution_memory_bytes)
memory_spilled_bytes += stage.metrics.memory_spilled_bytes
disk_spilled_bytes += stage.metrics.disk_spilled_bytes
result_size_bytes += stage.metrics.result_size_bytes

# https://github.com/apache/spark/blob/v3.5.1/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala#L467-L473
input_file_count = (
_get_int(execution.metrics, SparkSQLMetricNames.NUMBER_OF_FILES_READ)
or _get_int(execution.metrics, SparkSQLMetricNames.STATIC_NUMBER_OF_FILES_READ)
or 0
)
input_raw_file_bytes = (
_get_bytes(execution.metrics, SparkSQLMetricNames.SIZE_OF_FILES_READ)
or _get_bytes(execution.metrics, SparkSQLMetricNames.STATIC_SIZE_OF_FILES_READ)
or 0
)
input_read_partitions = _get_int(execution.metrics, SparkSQLMetricNames.NUMBER_OF_PARTITIONS_READ) or 0

input_query_time = _get_time(execution.metrics, SparkSQLMetricNames.JDBC_QUERY_EXECUTION_TIME) or timedelta(0)

output_files = _get_int(execution.metrics, SparkSQLMetricNames.NUMBER_OF_WRITTEN_FILES) or 0
output_dynamic_partitions = _get_int(execution.metrics, SparkSQLMetricNames.NUMBER_OF_DYNAMIC_PART) or 0

return SparkCommandMetrics(
input=SparkInputMetrics(
read_rows=input_read_rows,
read_files=input_file_count,
read_bytes=input_read_bytes,
raw_file_bytes=input_raw_file_bytes,
read_partitions=input_read_partitions,
query_time=input_query_time,
),
output=SparkOutputMetrics(
written_rows=output_rows,
written_bytes=output_bytes,
created_files=output_files,
created_partitions=output_dynamic_partitions,
),
driver=SparkDriverMetrics(
in_memory_bytes=result_size_bytes,
),
executor=SparkExecutorMetrics(
total_run_time=timedelta(milliseconds=run_time_milliseconds),
total_cpu_time=timedelta(microseconds=cpu_time_nanoseconds / 1000),
peak_memory_bytes=peak_memory_bytes,
memory_spilled_bytes=memory_spilled_bytes,
disk_spilled_bytes=disk_spilled_bytes,
),
)
29 changes: 29 additions & 0 deletions onetl/metrics/_listener/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
# SPDX-License-Identifier: Apache-2.0
from onetl.metrics._listener.execution import (
SparkListenerExecution,
SparkListenerExecutionStatus,
SparkSQLMetricNames,
)
from onetl.metrics._listener.job import SparkListenerJob, SparkListenerJobStatus
from onetl.metrics._listener.listener import SparkMetricsListener
from onetl.metrics._listener.stage import SparkListenerStage, SparkListenerStageStatus
from onetl.metrics._listener.task import (
SparkListenerTask,
SparkListenerTaskMetrics,
SparkListenerTaskStatus,
)

__all__ = [
"SparkListenerTask",
"SparkListenerTaskStatus",
"SparkListenerTaskMetrics",
"SparkListenerStage",
"SparkListenerStageStatus",
"SparkListenerJob",
"SparkListenerJobStatus",
"SparkListenerExecution",
"SparkListenerExecutionStatus",
"SparkSQLMetricNames",
"SparkMetricsListener",
]
Loading

0 comments on commit 5c068cd

Please sign in to comment.