Skip to content

Commit

Permalink
[DOP-18743] Set default jobDescription
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Aug 8, 2024
1 parent c4a9cb8 commit e56136d
Show file tree
Hide file tree
Showing 53 changed files with 549 additions and 386 deletions.
3 changes: 3 additions & 0 deletions docs/changelog/next_release/304.breaking.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Change connection URL used for generating HWM names of S3 and Samba sources:
* ``smb://host:port`` -> ``smb://host:port/share``
* ``s3://host:port`` -> ``s3://host:port/bucket``
6 changes: 6 additions & 0 deletions docs/changelog/next_release/304.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Generate default ``jobDescription`` based on currently executed method. Examples:
* ``DBWriter() -> Postgres[host:5432/database]``
* ``MongoDB[localhost:27017/admin] -> DBReader.run()``
* ``Hive[cluster].execute()``

If the user has already set a custom ``jobDescription``, it will be left intact.
4 changes: 2 additions & 2 deletions onetl/_util/hadoop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def get_hadoop_version(spark_session: SparkSession) -> Version:
"""
Get version of Hadoop libraries embedded to Spark
"""
jvm = spark_session._jvm # noqa: WPS437
jvm = spark_session._jvm # noqa: WPS437 # type: ignore[attr-defined]
version_info = jvm.org.apache.hadoop.util.VersionInfo # type: ignore[union-attr]
hadoop_version: str = version_info.getVersion()
return Version(hadoop_version)
Expand All @@ -24,4 +24,4 @@ def get_hadoop_config(spark_session: SparkSession):
"""
Get ``org.apache.hadoop.conf.Configuration`` object
"""
return spark_session.sparkContext._jsc.hadoopConfiguration()
return spark_session.sparkContext._jsc.hadoopConfiguration() # type: ignore[attr-defined]
2 changes: 1 addition & 1 deletion onetl/_util/java.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def get_java_gateway(spark_session: SparkSession) -> JavaGateway:
"""
Get py4j Java gateway object
"""
return spark_session._sc._gateway # noqa: WPS437 # type: ignore
return spark_session._sc._gateway # noqa: WPS437 # type: ignore[attr-defined]


def try_import_java_class(spark_session: SparkSession, name: str):
Expand Down
23 changes: 23 additions & 0 deletions onetl/_util/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from pyspark.sql import SparkSession
from pyspark.sql.conf import RuntimeConfig

# Key of the Spark local property that stores the description of the current job
SPARK_JOB_DESCRIPTION_PROPERTY = "spark.job.description"
# Key of the Spark local property for the job group id
# NOTE(review): presumably the property behind ``SparkContext.setJobGroup`` - confirm
SPARK_JOB_GROUP_PROPERTY = "spark.jobGroup.id"


def stringify(value: Any, quote: bool = False) -> Any: # noqa: WPS212
"""
Expand Down Expand Up @@ -185,3 +188,23 @@ def get_executor_total_cores(spark_session: SparkSession, include_driver: bool =
expected_cores += 1

return expected_cores, config


@contextmanager
def override_job_description(spark_session: SparkSession, job_description: str):
    """
    Temporarily set Spark job description for the duration of a ``with`` block.

    On enter, the ``spark.job.description`` local property is set to ``job_description``,
    unless a non-empty description is already present - in that case the existing value is kept.
    On exit (normal or via exception) the property is set back to the value captured on enter,
    so, unlike ``spark_session.sparkContext.setJobDescription``, the description does not
    outlive the context manager.

    Parameters
    ----------
    spark_session :
        Spark session which context local property is changed.
    job_description :
        Description to use if nothing was set before.
    """
    context = spark_session.sparkContext
    previous = context.getLocalProperty(SPARK_JOB_DESCRIPTION_PROPERTY)
    # already existing description takes precedence over the generated one
    effective = previous if previous else job_description

    try:
        context.setLocalProperty(SPARK_JOB_DESCRIPTION_PROPERTY, effective)
        yield
    finally:
        # put back exactly what was captured before entering the block
        context.setLocalProperty(SPARK_JOB_DESCRIPTION_PROPERTY, previous)  # type: ignore[arg-type]
3 changes: 2 additions & 1 deletion onetl/base/base_db_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

if TYPE_CHECKING:
from etl_entities.hwm import HWM
from pyspark.sql import DataFrame
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import StructField, StructType


Expand Down Expand Up @@ -106,6 +106,7 @@ class BaseDBConnection(BaseConnection):
Implements generic methods for reading and writing dataframe from/to database-like source
"""

spark: SparkSession
Dialect = BaseDBDialect

@property
Expand Down
4 changes: 3 additions & 1 deletion onetl/base/base_file_df_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from onetl.base.pure_path_protocol import PurePathProtocol

if TYPE_CHECKING:
from pyspark.sql import DataFrame, DataFrameReader, DataFrameWriter
from pyspark.sql import DataFrame, DataFrameReader, DataFrameWriter, SparkSession
from pyspark.sql.types import StructType


Expand Down Expand Up @@ -72,6 +72,8 @@ class BaseFileDFConnection(BaseConnection):
.. versionadded:: 0.9.0
"""

spark: SparkSession

@abstractmethod
def check_if_format_supported(
self,
Expand Down
3 changes: 3 additions & 0 deletions onetl/connection/db_connection/clickhouse/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def jdbc_params(self) -> dict:
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}]"

@staticmethod
def _build_statement(
statement: str,
Expand Down
3 changes: 3 additions & 0 deletions onetl/connection/db_connection/greenplum/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ def package_spark_3_2(cls) -> str:
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.database}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.database}]"

@property
def jdbc_url(self) -> str:
return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}"
Expand Down
18 changes: 15 additions & 3 deletions onetl/connection/db_connection/hive/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
except (ImportError, AttributeError):
from pydantic import validator # type: ignore[no-redef, assignment]

from onetl._util.spark import inject_spark_param
from onetl._util.spark import inject_spark_param, override_job_description
from onetl._util.sql import clear_statement
from onetl.connection.db_connection.db_connection import DBConnection
from onetl.connection.db_connection.hive.dialect import HiveDialect
Expand Down Expand Up @@ -158,6 +158,9 @@ def get_current(cls, spark: SparkSession):
def instance_url(self) -> str:
return self.cluster

def __str__(self):
return f"{self.__class__.__name__}[{self.cluster}]"

@slot
def check(self):
log.debug("|%s| Detecting current cluster...", self.__class__.__name__)
Expand Down Expand Up @@ -210,7 +213,11 @@ def sql(
log.info("|%s| Executing SQL query:", self.__class__.__name__)
log_lines(log, query)

df = self._execute_sql(query)
with override_job_description(
self.spark,
f"{self}.sql()",
):
df = self._execute_sql(query)
log.info("|Spark| DataFrame successfully created from SQL statement")
return df

Expand All @@ -236,7 +243,12 @@ def execute(
log.info("|%s| Executing statement:", self.__class__.__name__)
log_lines(log, statement)

self._execute_sql(statement).collect()
with override_job_description(
self.spark,
f"{self}.execute()",
):
self._execute_sql(statement).collect()

log.info("|%s| Call succeeded", self.__class__.__name__)

@slot
Expand Down
7 changes: 6 additions & 1 deletion onetl/connection/db_connection/jdbc_connection/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
from typing import TYPE_CHECKING, Any

from onetl._util.spark import override_job_description
from onetl._util.sql import clear_statement
from onetl.connection.db_connection.db_connection import DBConnection
from onetl.connection.db_connection.jdbc_connection.dialect import JDBCDialect
Expand Down Expand Up @@ -92,7 +93,11 @@ def sql(
log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__)
log_lines(log, query)

df = self._query_on_executor(query, self.SQLOptions.parse(options))
with override_job_description(
self.spark,
f"{self}.sql()",
):
df = self._query_on_executor(query, self.SQLOptions.parse(options))

log.info("|Spark| DataFrame successfully created from SQL statement ")
return df
Expand Down
54 changes: 31 additions & 23 deletions onetl/connection/db_connection/jdbc_mixin/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pydantic import Field, PrivateAttr, SecretStr, validator # type: ignore[no-redef, assignment]

from onetl._util.java import get_java_gateway, try_import_java_class
from onetl._util.spark import get_spark_version, stringify
from onetl._util.spark import get_spark_version, override_job_description, stringify
from onetl._util.sql import clear_statement
from onetl._util.version import Version
from onetl.connection.db_connection.jdbc_mixin.options import (
Expand Down Expand Up @@ -204,20 +204,23 @@ def fetch(
log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__)
log_lines(log, query)

df = self._query_on_driver(
query,
(
self.FetchOptions.parse(options.dict()) # type: ignore
if isinstance(options, JDBCMixinOptions)
else self.FetchOptions.parse(options)
),
call_options = (
self.FetchOptions.parse(options.dict()) # type: ignore
if isinstance(options, JDBCMixinOptions)
else self.FetchOptions.parse(options)
)

log.info(
"|%s| Query succeeded, resulting in-memory dataframe contains %d rows",
self.__class__.__name__,
df.count(),
)
with override_job_description(
self.spark,
f"{self}.fetch()",
):
df = self._query_on_driver(query, call_options)

log.info(
"|%s| Query succeeded, resulting in-memory dataframe contains %d rows",
self.__class__.__name__,
df.count(),
)
return df

@slot
Expand Down Expand Up @@ -273,17 +276,22 @@ def execute(
if isinstance(options, JDBCMixinOptions)
else self.ExecuteOptions.parse(options)
)
df = self._call_on_driver(statement, call_options)

if df is not None:
rows_count = df.count()
log.info(
"|%s| Execution succeeded, resulting in-memory dataframe contains %d rows",
self.__class__.__name__,
rows_count,
)
else:
log.info("|%s| Execution succeeded, nothing returned", self.__class__.__name__)
with override_job_description(
self.spark,
f"{self}.execute()",
):
df = self._call_on_driver(statement, call_options)

if df is not None:
rows_count = df.count()
log.info(
"|%s| Execution succeeded, resulting in-memory dataframe contains %d rows",
self.__class__.__name__,
rows_count,
)
else:
log.info("|%s| Execution succeeded, nothing returned", self.__class__.__name__)
return df

@validator("spark")
Expand Down
7 changes: 5 additions & 2 deletions onetl/connection/db_connection/kafka/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def get_min_max_values(
# https://kafka.apache.org/22/javadoc/org/apache/kafka/clients/consumer/KafkaConsumer.html#partitionsFor-java.lang.String-
partition_infos = consumer.partitionsFor(source)

jvm = self.spark._jvm
jvm = self.spark._jvm # type: ignore[attr-defined]
topic_partitions = [
jvm.org.apache.kafka.common.TopicPartition(source, p.partition()) # type: ignore[union-attr]
for p in partition_infos
Expand Down Expand Up @@ -542,6 +542,9 @@ def get_min_max_values(
def instance_url(self):
return "kafka://" + self.cluster

def __str__(self):
return f"{self.__class__.__name__}[{self.cluster}]"

@root_validator(pre=True)
def _get_addresses_by_cluster(cls, values):
cluster = values.get("cluster")
Expand Down Expand Up @@ -639,7 +642,7 @@ def _get_java_consumer(self):
return consumer_class(connection_properties)

def _get_topics(self, timeout: int = 10) -> set[str]:
jvm = self.spark._jvm
jvm = self.spark._jvm # type: ignore[attr-defined]
# Maybe we should not pass explicit timeout at all,
# and instead use default.api.timeout.ms which is configurable via self.extra.
# Think about this next time if someone see issues in real use
Expand Down
20 changes: 14 additions & 6 deletions onetl/connection/db_connection/mongodb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from onetl._util.classproperty import classproperty
from onetl._util.java import try_import_java_class
from onetl._util.scala import get_default_scala_version
from onetl._util.spark import get_spark_version
from onetl._util.spark import get_spark_version, override_job_description
from onetl._util.version import Version
from onetl.connection.db_connection.db_connection import DBConnection
from onetl.connection.db_connection.mongodb.dialect import MongoDBDialect
Expand Down Expand Up @@ -347,17 +347,25 @@ def pipeline(
if pipeline:
read_options["aggregation.pipeline"] = json.dumps(pipeline)
read_options["connection.uri"] = self.connection_url
spark_reader = self.spark.read.format("mongodb").options(**read_options)

if df_schema:
spark_reader = spark_reader.schema(df_schema)
with override_job_description(
self.spark,
f"{self}.pipeline()",
):
spark_reader = self.spark.read.format("mongodb").options(**read_options)

return spark_reader.load()
if df_schema:
spark_reader = spark_reader.schema(df_schema)

return spark_reader.load()

@property
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.database}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.database}]"

@slot
def check(self):
log.info("|%s| Checking connection availability...", self.__class__.__name__)
Expand Down Expand Up @@ -532,7 +540,7 @@ def _check_java_class_imported(cls, spark):
return spark

def _collection_exists(self, source: str) -> bool:
jvm = self.spark._jvm
jvm = self.spark._jvm # type: ignore[attr-defined]
client = jvm.com.mongodb.client.MongoClients.create(self.connection_url) # type: ignore
collections = set(client.getDatabase(self.database).listCollectionNames().iterator())
if source in collections:
Expand Down
9 changes: 9 additions & 0 deletions onetl/connection/db_connection/mssql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,12 @@ def instance_url(self) -> str:
# for backward compatibility keep port number in legacy HWM instance url
port = self.port or 1433
return f"{self.__class__.__name__.lower()}://{self.host}:{port}/{self.database}"

def __str__(self):
extra_dict = self.extra.dict(by_alias=True)
instance_name = extra_dict.get("instanceName")
if instance_name:
return rf"{self.__class__.__name__}[{self.host}\{instance_name}/{self.database}]"

port = self.port or 1433
return f"{self.__class__.__name__}[{self.host}:{port}/{self.database}]"
3 changes: 3 additions & 0 deletions onetl/connection/db_connection/mysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,6 @@ def jdbc_params(self) -> dict:
@property
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}]"
Loading

0 comments on commit e56136d

Please sign in to comment.