diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ff260645a..06b9c2b35 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -114,7 +114,7 @@ repos: - black==24.4.2 - repo: https://github.com/pycqa/bandit - rev: 1.7.10 + rev: 1.8.0 hooks: - id: bandit args: diff --git a/docs/changelog/0.12.5.rst b/docs/changelog/0.12.5.rst new file mode 100644 index 000000000..f08df5d0f --- /dev/null +++ b/docs/changelog/0.12.5.rst @@ -0,0 +1,15 @@ +0.12.5 (2024-12-03) +=================== + +Improvements +------------ + +- Use ``sipHash64`` instead of ``md5`` in Clickhouse for reading data with ``{"partitioning_mode": "hash"}``, as it is 5 times faster. +- Use ``hashtext`` instead of ``md5`` in Postgres for reading data with ``{"partitioning_mode": "hash"}``, as it is 3-5 times faster. +- Use ``BINARY_CHECKSUM`` instead of ``HASHBYTES`` in MSSQL for reading data with ``{"partitioning_mode": "hash"}``, as it is 5 times faster. + +Bug fixes +--------- + +- In JDBC sources wrap ``MOD(partitionColumn, numPartitions)`` with ``ABS(...)`` to make all returned values positive. This prevents data skew. +- Fix reading table data from MSSQL using ``{"partitioning_mode": "hash"}`` with ``partitionColumn`` of integer type. diff --git a/docs/changelog/index.rst b/docs/changelog/index.rst index 647dcc1f5..6f0532bf6 100644 --- a/docs/changelog/index.rst +++ b/docs/changelog/index.rst @@ -3,6 +3,7 @@ :caption: Changelog DRAFT + 0.12.5 0.12.4 0.12.3 0.12.2 diff --git a/onetl/VERSION b/onetl/VERSION index e01e0ddd8..43c2417ca 100644 --- a/onetl/VERSION +++ b/onetl/VERSION @@ -1 +1 @@ -0.12.4 +0.12.5 diff --git a/onetl/connection/db_connection/clickhouse/dialect.py b/onetl/connection/db_connection/clickhouse/dialect.py index 1ee213d0e..9e2a68b4a 100644 --- a/onetl/connection/db_connection/clickhouse/dialect.py +++ b/onetl/connection/db_connection/clickhouse/dialect.py @@ -10,10 +10,15 @@ class ClickhouseDialect(JDBCDialect): def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: - return f"halfMD5({partition_column}) % {num_partitions}" + # SipHash is 3 times faster than MD5 + # https://clickhouse.com/docs/en/sql-reference/functions/hash-functions#siphash64 + return f"sipHash64({partition_column}) % {num_partitions}" def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} % {num_partitions}" + # Return positive value even for negative input. + # Don't use positiveModulo as it is 4-5 times slower: + # https://clickhouse.com/docs/en/sql-reference/functions/arithmetic-functions#positivemoduloa-b + return f"abs({partition_column} % {num_partitions})" def get_max_value(self, value: Any) -> str: # Max function in Clickhouse returns 0 instead of NULL for empty table diff --git a/onetl/connection/db_connection/jdbc_connection/options.py b/onetl/connection/db_connection/jdbc_connection/options.py index c04e88509..11739736b 100644 --- a/onetl/connection/db_connection/jdbc_connection/options.py +++ b/onetl/connection/db_connection/jdbc_connection/options.py @@ -155,11 +155,10 @@ class Config: .. note:: Column type depends on :obj:`~partitioning_mode`. - * ``partitioning_mode="range"`` requires column to be an integer or date (can be NULL, but not recommended). - * ``partitioning_mode="hash"`` requires column to be an string (NOT NULL). + * ``partitioning_mode="range"`` requires column to be an integer, date or timestamp (can be NULL, but not recommended).
+ * ``partitioning_mode="hash"`` accepts any column type (NOT NULL). * ``partitioning_mode="mod"`` requires column to be an integer (NOT NULL). - See documentation for :obj:`~partitioning_mode` for more details""" num_partitions: PositiveInt = Field(default=1, alias="numPartitions") @@ -256,6 +255,10 @@ class Config: Where ``stride=(upper_bound - lower_bound) / num_partitions``. + .. note:: + + Can be used only with columns of integer, date or timestamp types. + .. note:: :obj:`~lower_bound`, :obj:`~upper_bound` and :obj:`~num_partitions` are used just to @@ -297,7 +300,7 @@ class Config: .. note:: The hash function implementation depends on RDBMS. It can be ``MD5`` or any other fast hash function, - or expression based on this function call. + or expression based on this function call. Usually such functions accepts any column type as an input. * ``mod`` Allocate each executor a set of values based on modulus of the :obj:`~partition_column` column. @@ -325,6 +328,10 @@ class Config: SELECT ... FROM table WHERE (partition_column mod num_partitions) = num_partitions-1 -- upper_bound + .. note:: + + Can be used only with columns of integer type. + .. versionadded:: 0.5.0 Examples diff --git a/onetl/connection/db_connection/mssql/dialect.py b/onetl/connection/db_connection/mssql/dialect.py index 3cb809ad2..96d5a175c 100644 --- a/onetl/connection/db_connection/mssql/dialect.py +++ b/onetl/connection/db_connection/mssql/dialect.py @@ -8,12 +8,15 @@ class MSSQLDialect(JDBCDialect): - # https://docs.microsoft.com/ru-ru/sql/t-sql/functions/hashbytes-transact-sql?view=sql-server-ver16 def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: - return f"CONVERT(BIGINT, HASHBYTES ('SHA', {partition_column})) % {num_partitions}" + # CHECKSUM/BINARY_CHECKSUM are faster than MD5 in 5 times: + # https://stackoverflow.com/a/4691861/23601543 + # https://learn.microsoft.com/en-us/sql/t-sql/functions/checksum-transact-sql?view=sql-server-ver16 + return f"ABS(BINARY_CHECKSUM({partition_column})) % {num_partitions}" def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} % {num_partitions}" + # Return positive value even for negative input + return f"ABS({partition_column} % {num_partitions})" def get_sql_query( self, diff --git a/onetl/connection/db_connection/mysql/dialect.py b/onetl/connection/db_connection/mysql/dialect.py index 5b59bc385..07b7ad6e3 100644 --- a/onetl/connection/db_connection/mysql/dialect.py +++ b/onetl/connection/db_connection/mysql/dialect.py @@ -9,10 +9,13 @@ class MySQLDialect(JDBCDialect): def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: - return f"MOD(CONV(CONV(RIGHT(MD5({partition_column}), 16), 16, 2), 2, 10), {num_partitions})" + # MD5 is the fastest hash function https://stackoverflow.com/a/3118889/23601543 + # But it returns 32 char string (128 bit), which we need to convert to integer + return f"CAST(CONV(RIGHT(MD5({partition_column}), 16), 16, 10) AS UNSIGNED) % {num_partitions}" def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: - return f"MOD({partition_column}, {num_partitions})" + # Return positive value even for negative input + return f"ABS({partition_column} % {num_partitions})" def escape_column(self, value: str) -> str: return f"`{value}`" diff --git a/onetl/connection/db_connection/oracle/dialect.py b/onetl/connection/db_connection/oracle/dialect.py index c7a739039..5d4844003 100644 --- 
a/onetl/connection/db_connection/oracle/dialect.py +++ b/onetl/connection/db_connection/oracle/dialect.py @@ -48,7 +48,8 @@ def get_partition_column_hash(self, partition_column: str, num_partitions: int) return f"ora_hash({partition_column}, {num_partitions - 1})" def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: - return f"MOD({partition_column}, {num_partitions})" + # Return positive value even for negative input + return f"ABS(MOD({partition_column}, {num_partitions}))" def _serialize_datetime(self, value: datetime) -> str: result = value.strftime("%Y-%m-%d %H:%M:%S") diff --git a/onetl/connection/db_connection/postgres/dialect.py b/onetl/connection/db_connection/postgres/dialect.py index 1dca8ec99..30509addb 100644 --- a/onetl/connection/db_connection/postgres/dialect.py +++ b/onetl/connection/db_connection/postgres/dialect.py @@ -9,12 +9,14 @@ class PostgresDialect(NotSupportHint, JDBCDialect): - # https://stackoverflow.com/a/9812029 def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: - return f"('x'||right(md5('{partition_column}'), 16))::bit(32)::bigint % {num_partitions}" + # hashtext is about 3-5 times faster than MD5 (tested locally) + # https://postgrespro.com/list/thread-id/1506406 + return f"abs(hashtext({partition_column}::text)) % {num_partitions}" def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} % {num_partitions}" + # Return positive value even for negative input + return f"abs({partition_column} % {num_partitions})" def _serialize_datetime(self, value: datetime) -> str: result = value.isoformat() diff --git a/onetl/connection/db_connection/teradata/dialect.py b/onetl/connection/db_connection/teradata/dialect.py index b7fc9c47f..81f5c1f52 100644 --- a/onetl/connection/db_connection/teradata/dialect.py +++ b/onetl/connection/db_connection/teradata/dialect.py @@ -13,7 +13,8 @@ def get_partition_column_hash(self, partition_column: str, num_partitions: int) return f"HASHAMP(HASHBUCKET(HASHROW({partition_column}))) mod {num_partitions}" def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} mod {num_partitions}" + # Return positive value even for negative input + return f"ABS({partition_column} mod {num_partitions})" def _serialize_datetime(self, value: datetime) -> str: result = value.isoformat() diff --git a/tests/fixtures/processing/base_processing.py b/tests/fixtures/processing/base_processing.py index c3159f334..5bf845aa3 100644 --- a/tests/fixtures/processing/base_processing.py +++ b/tests/fixtures/processing/base_processing.py @@ -1,5 +1,6 @@ from __future__ import annotations +import secrets from abc import ABC, abstractmethod from collections import defaultdict from datetime import date, datetime, timedelta @@ -137,7 +138,7 @@ def create_pandas_df( elif "float" in column_name: values[column].append(float(f"{i}.{i}")) elif "text" in column_name: - values[column].append("This line is made to test the work") + values[column].append(secrets.token_hex(16)) elif "datetime" in column_name: rand_second = randint(0, i * time_multiplier) # noqa: S311 values[column].append(self.current_datetime() + timedelta(seconds=rand_second)) diff --git a/tests/fixtures/processing/clickhouse.py b/tests/fixtures/processing/clickhouse.py index bf0b2f3e7..65d06d82b 100644 --- a/tests/fixtures/processing/clickhouse.py +++ b/tests/fixtures/processing/clickhouse.py @@ -1,6 +1,7 @@ from __future__ 
import annotations import os +import secrets from collections import defaultdict from datetime import date, datetime, timedelta from logging import getLogger @@ -74,7 +75,7 @@ def create_pandas_df(self, min_id: int = 1, max_id: int | None = None) -> pandas elif "float" in column_name: values[column_name].append(float(f"{i}.{i}")) elif "text" in column_name: - values[column_name].append("This line is made to test the work") + values[column_name].append(secrets.token_hex(16)) elif "datetime" in column_name: rand_second = randint(0, i * time_multiplier) # noqa: S311 # Clickhouse DATETIME format has time range: 00:00:00 through 23:59:59 diff --git a/tests/fixtures/processing/mongodb.py b/tests/fixtures/processing/mongodb.py index bf3158a99..f6e063ccf 100644 --- a/tests/fixtures/processing/mongodb.py +++ b/tests/fixtures/processing/mongodb.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import secrets from collections import defaultdict from datetime import datetime, timedelta from logging import getLogger @@ -149,7 +150,7 @@ def create_pandas_df(self, min_id: int = 1, max_id: int | None = None) -> pandas elif "float" in column_name: values[column_name].append(float(f"{i}.{i}")) elif "text" in column_name: - values[column_name].append("This line is made to test the work") + values[column_name].append(secrets.token_hex(16)) elif "datetime" in column_name: rand_second = randint(0, i * time_multiplier) # noqa: S311 now = self.current_datetime() + timedelta(seconds=rand_second) diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_clickhouse_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_clickhouse_reader_integration.py index 72314b5b3..531cfd065 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_clickhouse_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_clickhouse_reader_integration.py @@ -7,6 +7,8 @@ except ImportError: pytest.skip("Missing pandas", allow_module_level=True) +from onetl._util.spark import get_spark_version +from onetl._util.version import Version from onetl.connection import Clickhouse from onetl.db import DBReader from tests.util.rand import rand_str @@ -38,15 +40,201 @@ def test_clickhouse_reader_snapshot(spark, processing, load_table_data): ) +def test_clickhouse_reader_snapshot_with_partitioning_mode_range_int(spark, processing, load_table_data): + from pyspark.sql.functions import spark_partition_id + + clickhouse = Clickhouse( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=clickhouse, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": "hwm_int", + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each partition should contain about ~33 rows with very low variance. 
+ average_count_per_partition = table_df.count() // table_df.rdd.getNumPartitions() + min_count_per_partition = average_count_per_partition - 1 + max_count_per_partition = average_count_per_partition + 1 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + +@pytest.mark.parametrize( + "bounds", + [ + pytest.param({"lowerBound": "50"}, id="lower_bound"), + pytest.param({"upperBound": "70"}, id="upper_bound"), + pytest.param({"lowerBound": "50", "upperBound": "70"}, id="both_bounds"), + ], +) +def test_clickhouse_reader_snapshot_with_partitioning_mode_range_int_explicit_bounds( + spark, + processing, + load_table_data, + bounds, +): + from pyspark.sql.functions import spark_partition_id + + clickhouse = Clickhouse( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=clickhouse, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": "hwm_int", + "numPartitions": 3, + **bounds, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + +@pytest.mark.parametrize( + "column", + [ + "hwm_date", + "hwm_datetime", + ], +) +def test_clickhouse_reader_snapshot_with_partitioning_mode_range_date_datetime( + spark, + processing, + load_table_data, + column, +): + spark_version = get_spark_version(spark) + if spark_version < Version("2.4"): + # https://issues.apache.org/jira/browse/SPARK-22814 + pytest.skip("partitionColumn of date/datetime is supported only since 2.4.0") + + from pyspark.sql.functions import spark_partition_id + + clickhouse = Clickhouse( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=clickhouse, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + +@pytest.mark.parametrize( + "column", + [ + "float_value", + "text_string", + ], +) +def test_clickhouse_reader_snapshot_with_partitioning_mode_range_unsupported_column_type( + spark, + processing, + load_table_data, + column, +): + clickhouse = Clickhouse( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=clickhouse, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + + with pytest.raises(Exception): + reader.run() + + @pytest.mark.parametrize( - "mode, column", + "column", [ - ("range", "id_int"), - ("hash", "text_string"), - ("mod", 
"id_int"), + # all column types are supported + "hwm_int", + "hwm_date", + "hwm_datetime", + "float_value", + "text_string", ], ) -def test_clickhouse_reader_snapshot_partitioning_mode(mode, column, spark, processing, load_table_data): +def test_clickhouse_reader_snapshot_with_partitioning_mode_hash(spark, processing, load_table_data, column): + from pyspark.sql.functions import spark_partition_id + clickhouse = Clickhouse( host=processing.host, port=processing.port, @@ -59,13 +247,63 @@ def test_clickhouse_reader_snapshot_partitioning_mode(mode, column, spark, proce reader = DBReader( connection=clickhouse, source=load_table_data.full_name, - options=Clickhouse.ReadOptions( - partitioning_mode=mode, - partition_column=column, - num_partitions=5, - ), + options={ + "partitioning_mode": "hash", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each partition should contain about ~33 rows, + # with some variance caused by randomness & hash distribution + min_count_per_partition = 10 + max_count_per_partition = 55 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + +@pytest.mark.parametrize( + "column", + [ + "hwm_int", + "float_value", + ], +) +def test_clickhouse_reader_snapshot_with_partitioning_mode_mod_number(spark, processing, load_table_data, column): + from pyspark.sql.functions import spark_partition_id + + clickhouse = Clickhouse( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, ) + reader = DBReader( + connection=clickhouse, + source=load_table_data.full_name, + options={ + "partitioning_mode": "mod", + "partitionColumn": column, + "numPartitions": 3, + }, + ) table_df = reader.run() processing.assert_equal_df( @@ -75,7 +313,99 @@ def test_clickhouse_reader_snapshot_partitioning_mode(mode, column, spark, proce order_by="id_int", ) - assert table_df.rdd.getNumPartitions() == 5 + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each partition should contain about ~33 rows with very low variance. 
+ average_count_per_partition = table_df.count() // table_df.rdd.getNumPartitions() + min_count_per_partition = average_count_per_partition - 1 + max_count_per_partition = average_count_per_partition + 1 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + +# Apparently, Clickhouse supports `date % number` operation +def test_clickhouse_reader_snapshot_with_partitioning_mode_mod_date(spark, processing, load_table_data): + from pyspark.sql.functions import spark_partition_id + + clickhouse = Clickhouse( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=clickhouse, + source=load_table_data.full_name, + options={ + "partitioning_mode": "mod", + "partitionColumn": "hwm_date", + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each partition should contain about ~33 rows, + # with some variance caused by randomness & hash distribution + min_count_per_partition = 10 + max_count_per_partition = 55 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + +@pytest.mark.parametrize( + "column", + [ + "hwm_datetime", + "text_string", + ], +) +def test_clickhouse_reader_snapshot_with_partitioning_mode_mod_unsupported_column_type( + spark, + processing, + load_table_data, + column, +): + clickhouse = Clickhouse( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=clickhouse, + source=load_table_data.full_name, + options={ + "partitioning_mode": "mod", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + + with pytest.raises(Exception, match=r"Illegal types .* of arguments of function modulo"): + reader.run() def test_clickhouse_reader_snapshot_without_set_database(spark, processing, load_table_data): diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mssql_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mssql_reader_integration.py index 781121e4b..a6129d796 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mssql_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mssql_reader_integration.py @@ -7,6 +7,8 @@ except ImportError: pytest.skip("Missing pandas", allow_module_level=True) +from onetl._util.spark import get_spark_version +from onetl._util.version import Version from onetl.connection import MSSQL from onetl.db import DBReader from tests.util.rand import rand_str @@ -39,15 +41,165 @@ def test_mssql_reader_snapshot(spark, processing, load_table_data): ) +def test_mssql_reader_snapshot_with_partitioning_mode_range_int(spark, processing, load_table_data): + from 
pyspark.sql.functions import spark_partition_id + + mssql = MSSQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={"trustServerCertificate": "true"}, + ) + + reader = DBReader( + connection=mssql, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": "hwm_int", + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each partition should contain about ~33 rows with very low variance. + average_count_per_partition = table_df.count() // table_df.rdd.getNumPartitions() + min_count_per_partition = average_count_per_partition - 1 + max_count_per_partition = average_count_per_partition + 1 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + +@pytest.mark.parametrize( + "bounds", + [ + pytest.param({"lowerBound": "50"}, id="lower_bound"), + pytest.param({"upperBound": "70"}, id="upper_bound"), + pytest.param({"lowerBound": "50", "upperBound": "70"}, id="both_bounds"), + ], +) +def test_mssql_reader_snapshot_with_partitioning_mode_range_int_explicit_bounds( + spark, + processing, + load_table_data, + bounds, +): + from pyspark.sql.functions import spark_partition_id + + mssql = MSSQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={"trustServerCertificate": "true"}, + ) + + reader = DBReader( + connection=mssql, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": "hwm_int", + "numPartitions": 3, + **bounds, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + +# sometimes fails with 'Conversion failed when converting date and/or time from character string.' 
+@pytest.mark.flaky(reruns=10) +@pytest.mark.parametrize( + "column", + [ + "hwm_date", + "hwm_datetime", + ], +) +def test_mssql_reader_snapshot_with_partitioning_mode_range_date_datetime(spark, processing, load_table_data, column): + spark_version = get_spark_version(spark) + if spark_version < Version("2.4"): + # https://issues.apache.org/jira/browse/SPARK-22814 + pytest.skip("partitionColumn of date/datetime is supported only since 2.4.0") + + from pyspark.sql.functions import spark_partition_id + + mssql = MSSQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={"trustServerCertificate": "true"}, + ) + + reader = DBReader( + connection=mssql, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + @pytest.mark.parametrize( - "mode, column", + "column", [ - ("range", "id_int"), - ("hash", "text_string"), - ("mod", "id_int"), + "float_value", + "text_string", ], ) -def test_mssql_reader_snapshot_partitioning_mode(mode, column, spark, processing, load_table_data): +def test_mssql_reader_snapshot_with_partitioning_mode_range_unsupported_column_type( + spark, + processing, + load_table_data, + column, +): mssql = MSSQL( host=processing.host, port=processing.port, @@ -55,19 +207,58 @@ def test_mssql_reader_snapshot_partitioning_mode(mode, column, spark, processing password=processing.password, database=processing.database, spark=spark, - extra={"trustServerCertificate": "true"}, # avoid SSL problem + extra={"trustServerCertificate": "true"}, ) reader = DBReader( connection=mssql, source=load_table_data.full_name, - options=MSSQL.ReadOptions( - partitioning_mode=mode, - partition_column=column, - num_partitions=5, - ), + options={ + "partitioning_mode": "range", + "partitionColumn": column, + "numPartitions": 3, + }, ) + with pytest.raises(Exception): + reader.run() + + +@pytest.mark.parametrize( + "column", + [ + # all column types are supported + "hwm_int", + "hwm_date", + "hwm_datetime", + "float_value", + "text_string", + # hash of the entire row is supported as well + "*", + ], +) +def test_mssql_reader_snapshot_with_partitioning_mode_hash(spark, processing, load_table_data, column): + from pyspark.sql.functions import spark_partition_id + + mssql = MSSQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={"trustServerCertificate": "true"}, + ) + + reader = DBReader( + connection=mssql, + source=load_table_data.full_name, + options={ + "partitioning_mode": "hash", + "partitionColumn": column, + "numPartitions": 3, + }, + ) table_df = reader.run() processing.assert_equal_df( @@ -77,7 +268,107 @@ def test_mssql_reader_snapshot_partitioning_mode(mode, column, spark, processing order_by="id_int", ) - assert table_df.rdd.getNumPartitions() == 5 + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each 
partition should contain about ~33 rows, + # with some variance caused by randomness & hash distribution + min_count_per_partition = 10 + max_count_per_partition = 55 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + +def test_mssql_reader_snapshot_with_partitioning_mode_mod(spark, processing, load_table_data): + from pyspark.sql.functions import spark_partition_id + + mssql = MSSQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={"trustServerCertificate": "true"}, + ) + + reader = DBReader( + connection=mssql, + source=load_table_data.full_name, + options={ + "partitioning_mode": "mod", + "partitionColumn": "hwm_int", + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each partition should contain about ~33 rows with very low variance. + average_count_per_partition = table_df.count() // table_df.rdd.getNumPartitions() + min_count_per_partition = average_count_per_partition - 1 + max_count_per_partition = average_count_per_partition + 1 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + +@pytest.mark.parametrize( + "column", + [ + "hwm_date", + "hwm_datetime", + "float_value", + "text_string", + ], +) +def test_mssql_reader_snapshot_with_partitioning_mode_mod_unsupported_column_type( + spark, + processing, + load_table_data, + column, +): + mssql = MSSQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={"trustServerCertificate": "true"}, + ) + + reader = DBReader( + connection=mssql, + source=load_table_data.full_name, + options={ + "partitioning_mode": "mod", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + + with pytest.raises( + Exception, + match=r"are incompatible in the modulo operator|Conversion failed .* to data type int", + ): + table_df = reader.run() + table_df.count() def test_mssql_reader_snapshot_with_columns(spark, processing, load_table_data): diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mysql_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mysql_reader_integration.py index 3f12746b8..64e1f4d27 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mysql_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mysql_reader_integration.py @@ -7,6 +7,8 @@ except ImportError: pytest.skip("Missing pandas", allow_module_level=True) +from onetl._util.spark import get_spark_version +from onetl._util.version import Version from onetl.connection import MySQL from onetl.db import DBReader from tests.util.rand import rand_str @@ -39,15 +41,67 @@ def test_mysql_reader_snapshot(spark, processing, 
load_table_data): ) +def test_mysql_reader_snapshot_with_partitioning_mode_range_int(spark, processing, load_table_data): + from pyspark.sql.functions import spark_partition_id + + mysql = MySQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=mysql, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": "hwm_int", + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each partition should contain about ~33 rows with very low variance. + average_count_per_partition = table_df.count() // table_df.rdd.getNumPartitions() + min_count_per_partition = average_count_per_partition - 1 + max_count_per_partition = average_count_per_partition + 1 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + @pytest.mark.parametrize( - "mode, column", + "bounds", [ - ("range", "id_int"), - ("hash", "text_string"), - ("mod", "id_int"), + pytest.param({"lowerBound": "50"}, id="lower_bound"), + pytest.param({"upperBound": "70"}, id="upper_bound"), + pytest.param({"lowerBound": "50", "upperBound": "70"}, id="both_bounds"), ], ) -def test_mysql_reader_snapshot_partitioning_mode(mode, column, spark, processing, load_table_data): +def test_mysql_reader_snapshot_with_partitioning_mode_range_int_explicit_bounds( + spark, + processing, + load_table_data, + bounds, +): + from pyspark.sql.functions import spark_partition_id + mysql = MySQL( host=processing.host, port=processing.port, @@ -60,13 +114,60 @@ def test_mysql_reader_snapshot_partitioning_mode(mode, column, spark, processing reader = DBReader( connection=mysql, source=load_table_data.full_name, - options=MySQL.ReadOptions( - partitioning_mode=mode, - partition_column=column, - num_partitions=5, - ), + options={ + "partitioning_mode": "range", + "partitionColumn": "hwm_int", + "numPartitions": 3, + **bounds, + }, ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + +@pytest.mark.parametrize( + "column", + [ + "hwm_date", + "hwm_datetime", + ], +) +def test_mysql_reader_snapshot_with_partitioning_mode_range_date_datetime(spark, processing, load_table_data, column): + spark_version = get_spark_version(spark) + if spark_version < Version("2.4"): + # https://issues.apache.org/jira/browse/SPARK-22814 + pytest.skip("partitionColumn of date/datetime is supported only since 2.4.0") + + from pyspark.sql.functions import spark_partition_id + mysql = MySQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=mysql, + source=load_table_data.full_name, + options={ + 
"partitioning_mode": "range", + "partitionColumn": column, + "numPartitions": 3, + }, + ) table_df = reader.run() processing.assert_equal_df( @@ -76,7 +177,197 @@ def test_mysql_reader_snapshot_partitioning_mode(mode, column, spark, processing order_by="id_int", ) - assert table_df.rdd.getNumPartitions() == 5 + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + +@pytest.mark.parametrize( + "column", + [ + "float_value", + "text_string", + ], +) +def test_mysql_reader_snapshot_with_partitioning_mode_range_unsupported_column_type( + spark, + processing, + load_table_data, + column, +): + mysql = MySQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=mysql, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + + with pytest.raises(Exception): + reader.run() + + +@pytest.mark.parametrize( + "column", + [ + # all column types are supported + "hwm_int", + "hwm_date", + "hwm_datetime", + "float_value", + "text_string", + ], +) +def test_mysql_reader_snapshot_with_partitioning_mode_hash(spark, processing, load_table_data, column): + from pyspark.sql.functions import spark_partition_id + + mysql = MySQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=mysql, + source=load_table_data.full_name, + options={ + "partitioning_mode": "hash", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each partition should contain about ~33 rows, + # with some variance caused by randomness & hash distribution (+- 50% range is wide enough) + min_count_per_partition = 10 + max_count_per_partition = 55 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + +@pytest.mark.parametrize( + "column", + [ + "hwm_int", + "float_value", + ], +) +def test_mysql_reader_snapshot_with_partitioning_mode_mod_number(spark, processing, load_table_data, column): + from pyspark.sql.functions import spark_partition_id + + mysql = MySQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=mysql, + source=load_table_data.full_name, + options={ + "partitioning_mode": "mod", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 
partitions -> each partition should contain about ~33 rows with very low variance. + average_count_per_partition = table_df.count() // table_df.rdd.getNumPartitions() + min_count_per_partition = average_count_per_partition - 1 + max_count_per_partition = average_count_per_partition + 1 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + +# Apparently, MySQL supports modulus for any column type +@pytest.mark.parametrize( + "column", + [ + "hwm_date", + "hwm_datetime", + "text_string", + ], +) +def test_mysql_reader_snapshot_with_partitioning_mode_mod_other_column_type(spark, processing, load_table_data, column): + from pyspark.sql.functions import spark_partition_id + + mysql = MySQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=mysql, + source=load_table_data.full_name, + options={ + "partitioning_mode": "mod", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # for some reason, `MOD(text_string, N)` result is very skewed, don't assert on that. def test_mysql_reader_snapshot_with_not_set_database(spark, processing, load_table_data): diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py index d9864967c..f5a8491a0 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py @@ -39,15 +39,156 @@ def test_oracle_reader_snapshot(spark, processing, load_table_data): ) +def test_oracle_reader_snapshot_with_partitioning_mode_range_int(spark, processing, load_table_data): + from pyspark.sql.functions import spark_partition_id + + oracle = Oracle( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + spark=spark, + sid=processing.sid, + service_name=processing.service_name, + ) + + reader = DBReader( + connection=oracle, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": "hwm_int", + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each partition should contain about ~33 rows with very low variance. 
+ average_count_per_partition = table_df.count() // table_df.rdd.getNumPartitions() + min_count_per_partition = average_count_per_partition - 1 + max_count_per_partition = average_count_per_partition + 1 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + +@pytest.mark.parametrize( + "bounds", + [ + pytest.param({"lowerBound": "50"}, id="lower_bound"), + pytest.param({"upperBound": "70"}, id="upper_bound"), + pytest.param({"lowerBound": "50", "upperBound": "70"}, id="both_bounds"), + ], +) +def test_oracle_reader_snapshot_with_partitioning_mode_range_int_explicit_bounds( + spark, + processing, + load_table_data, + bounds, +): + from pyspark.sql.functions import spark_partition_id + + oracle = Oracle( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + spark=spark, + sid=processing.sid, + service_name=processing.service_name, + ) + + reader = DBReader( + connection=oracle, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": "hwm_int", + "numPartitions": 3, + **bounds, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + +@pytest.mark.parametrize( + "column", + [ + "hwm_date", + "hwm_datetime", + "float_value", + "text_string", + ], +) +def test_oracle_reader_snapshot_with_partitioning_mode_range_unsupported_column_type( + spark, + processing, + load_table_data, + column, +): + oracle = Oracle( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + spark=spark, + sid=processing.sid, + service_name=processing.service_name, + ) + + reader = DBReader( + connection=oracle, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + + with pytest.raises(Exception): + table_df = reader.run() + table_df.count() + + @pytest.mark.parametrize( - "mode, column", + "column", [ - ("range", "id_int"), - ("hash", "text_string"), - ("mod", "id_int"), + # all column types are supported + "hwm_int", + "hwm_date", + "hwm_datetime", + "float_value", + "text_string", ], ) -def test_oracle_reader_snapshot_partitioning_mode(mode, column, spark, processing, load_table_data): +def test_oracle_reader_snapshot_with_partitioning_mode_hash(spark, processing, load_table_data, column): + from pyspark.sql.functions import spark_partition_id + oracle = Oracle( host=processing.host, port=processing.port, @@ -61,13 +202,66 @@ def test_oracle_reader_snapshot_partitioning_mode(mode, column, spark, processin reader = DBReader( connection=oracle, source=load_table_data.full_name, - options=Oracle.ReadOptions( - partitioning_mode=mode, - partition_column=column, - num_partitions=5, - ), + options={ + "partitioning_mode": "hash", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + 
assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each partition should contain about ~33 rows, + # with some variance caused by randomness & hash distribution + min_count_per_partition = 10 + max_count_per_partition = 55 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + +# Apparently, Oracle supports modulus for text columns type +@pytest.mark.parametrize( + "column", + [ + "hwm_int", + "float_value", + ], +) +def test_oracle_reader_snapshot_with_partitioning_mode_mod_number(spark, processing, load_table_data, column): + from pyspark.sql.functions import spark_partition_id + + oracle = Oracle( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + spark=spark, + sid=processing.sid, + service_name=processing.service_name, ) + reader = DBReader( + connection=oracle, + source=load_table_data.full_name, + options={ + "partitioning_mode": "mod", + "partitionColumn": column, + "numPartitions": 3, + }, + ) table_df = reader.run() processing.assert_equal_df( @@ -77,7 +271,57 @@ def test_oracle_reader_snapshot_partitioning_mode(mode, column, spark, processin order_by="id_int", ) - assert table_df.rdd.getNumPartitions() == 5 + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each partition should contain about ~33 rows with very low variance. + average_count_per_partition = table_df.count() // table_df.rdd.getNumPartitions() + min_count_per_partition = average_count_per_partition - 1 + max_count_per_partition = average_count_per_partition + 1 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + +@pytest.mark.parametrize( + "column", + [ + "hwm_date", + "hwm_datetime", + "text_string", + ], +) +def test_oracle_reader_snapshot_with_partitioning_mode_mod_unsupported_column_type( + spark, + processing, + load_table_data, + column, +): + oracle = Oracle( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + spark=spark, + sid=processing.sid, + service_name=processing.service_name, + ) + + reader = DBReader( + connection=oracle, + source=load_table_data.full_name, + options={ + "partitioning_mode": "mod", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + + with pytest.raises(Exception): + table_df = reader.run() + table_df.count() def test_oracle_reader_snapshot_with_columns(spark, processing, load_table_data): diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py index 248a575b9..48d719d59 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py @@ -2,11 +2,14 @@ import pytest +from onetl._util.version import Version + try: import pandas except ImportError: pytest.skip("Missing pandas", allow_module_level=True) 
+from onetl._util.spark import get_spark_version from onetl.connection import Postgres from onetl.db import DBReader from tests.util.rand import rand_str @@ -278,7 +281,9 @@ def test_postgres_reader_snapshot_with_columns_and_where(spark, processing, load assert count_df.collect()[0][0] == table_df.count() -def test_postgres_reader_snapshot_with_pydantic_options(spark, processing, load_table_data): +def test_postgres_reader_snapshot_with_partitioning_mode_range_int(spark, processing, load_table_data): + from pyspark.sql.functions import spark_partition_id + postgres = Postgres( host=processing.host, port=processing.port, @@ -291,9 +296,71 @@ def test_postgres_reader_snapshot_with_pydantic_options(spark, processing, load_ reader = DBReader( connection=postgres, source=load_table_data.full_name, - options=Postgres.ReadOptions(fetchsize=500), + options={ + "partitioning_mode": "range", + "partitionColumn": "hwm_int", + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each partition should contain about ~33 rows with very low variance. + average_count_per_partition = table_df.count() // table_df.rdd.getNumPartitions() + min_count_per_partition = average_count_per_partition - 1 + max_count_per_partition = average_count_per_partition + 1 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + +@pytest.mark.parametrize( + "bounds", + [ + pytest.param({"lowerBound": "50"}, id="lower_bound"), + pytest.param({"upperBound": "70"}, id="upper_bound"), + pytest.param({"lowerBound": "50", "upperBound": "70"}, id="both_bounds"), + ], +) +def test_postgres_reader_snapshot_with_partitioning_mode_range_int_explicit_bounds( + spark, + processing, + load_table_data, + bounds, +): + from pyspark.sql.functions import spark_partition_id + + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, ) + reader = DBReader( + connection=postgres, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": "hwm_int", + "numPartitions": 3, + **bounds, + }, + ) table_df = reader.run() processing.assert_equal_df( @@ -303,25 +370,214 @@ def test_postgres_reader_snapshot_with_pydantic_options(spark, processing, load_ order_by="id_int", ) + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + +@pytest.mark.parametrize( + "column", + [ + "hwm_date", + "hwm_datetime", + ], +) +def test_postgres_reader_snapshot_with_partitioning_mode_range_date_datetime( + spark, + processing, + load_table_data, + column, +): + spark_version = get_spark_version(spark) + if spark_version < Version("2.4"): + # https://issues.apache.org/jira/browse/SPARK-22814 + pytest.skip("partitionColumn of date/datetime is supported only since 2.4.0") + + from pyspark.sql.functions import spark_partition_id + + postgres = Postgres( + host=processing.host, + port=processing.port, + 
user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=postgres, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + +@pytest.mark.parametrize( + "column", + [ + "float_value", + "text_string", + ], +) +def test_postgres_reader_snapshot_with_partitioning_mode_range_unsupported_column_type( + spark, + processing, + load_table_data, + column, +): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=postgres, + source=load_table_data.full_name, + options={ + "partitioning_mode": "range", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + + with pytest.raises(Exception): + reader.run() + @pytest.mark.parametrize( - "mode", + "column", [ - {"partitioning_mode": "range"}, - {"partitioning_mode": "hash"}, - {"partitioning_mode": "mod"}, + # all column types are supported + "hwm_int", + "hwm_date", + "hwm_datetime", + "float_value", + "text_string", ], ) +def test_postgres_reader_snapshot_with_partitioning_mode_hash(spark, processing, load_table_data, column): + from pyspark.sql.functions import spark_partition_id + + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=postgres, + source=load_table_data.full_name, + options={ + "partitioning_mode": "hash", + "partitionColumn": column, + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() == 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each partition should contain about ~33 rows, + # with some variance caused by randomness & hash distribution + min_count_per_partition = 10 + max_count_per_partition = 55 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + +def test_postgres_reader_snapshot_with_partitioning_mode_mod(spark, processing, load_table_data): + from pyspark.sql.functions import spark_partition_id + + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=postgres, + source=load_table_data.full_name, + options={ + "partitioning_mode": "mod", + "partitionColumn": "hwm_int", + "numPartitions": 3, + }, + ) + table_df = reader.run() + + processing.assert_equal_df( + schema=load_table_data.schema, + table=load_table_data.table, + df=table_df, + order_by="id_int", + ) + + assert table_df.rdd.getNumPartitions() 
== 3 + # So just check that any partition has at least 0 rows + assert table_df.groupBy(spark_partition_id()).count().count() == 3 + + # 100 rows per 3 partitions -> each partition should contain about ~33 rows with very low variance. + average_count_per_partition = table_df.count() // table_df.rdd.getNumPartitions() + min_count_per_partition = average_count_per_partition - 1 + max_count_per_partition = average_count_per_partition + 1 + + count_per_partition = table_df.groupBy(spark_partition_id()).count().collect() + for partition in count_per_partition: + assert min_count_per_partition <= partition["count"] <= max_count_per_partition + + @pytest.mark.parametrize( - "options", + "column", [ - {"numPartitions": "2", "partitionColumn": "hwm_int"}, - {"numPartitions": "2", "partitionColumn": "hwm_int", "lowerBound": "50"}, - {"numPartitions": "2", "partitionColumn": "hwm_int", "upperBound": "70"}, - {"fetchsize": "2"}, + "hwm_date", + "hwm_datetime", + "float_value", + "text_string", ], ) -def test_postgres_reader_different_options(spark, processing, load_table_data, options, mode): +def test_postgres_reader_snapshot_with_partitioning_mode_mod_unsupported_column_type( + spark, + processing, + load_table_data, + column, +): postgres = Postgres( host=processing.host, port=processing.port, @@ -334,8 +590,33 @@ def test_postgres_reader_different_options(spark, processing, load_table_data, o reader = DBReader( connection=postgres, source=load_table_data.full_name, - options=options.update(mode), + options={ + "partitioning_mode": "mod", + "partitionColumn": column, + "numPartitions": 3, + }, ) + + with pytest.raises(Exception, match="operator does not exist"): + reader.run() + + +def test_postgres_reader_snapshot_with_pydantic_options(spark, processing, load_table_data): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=postgres, + source=load_table_data.full_name, + options=Postgres.ReadOptions(fetchsize=500), + ) + table_df = reader.run() processing.assert_equal_df( diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_oracle.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_oracle.py index 2bb9165e6..69100f351 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_oracle.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_oracle.py @@ -123,6 +123,7 @@ def test_oracle_strategy_incremental( processing.assert_subset_df(df=second_df, other_frame=second_span) +@pytest.mark.flaky def test_oracle_strategy_incremental_nothing_to_read(spark, processing, prepare_schema_table): oracle = Oracle( host=processing.host, @@ -270,6 +271,7 @@ def test_oracle_strategy_incremental_wrong_hwm( reader.run() +@pytest.mark.flaky def test_oracle_strategy_incremental_explicit_hwm_type( spark, processing,