From 36fd574289f12d53a9ae4e37154ca7919a5b4573 Mon Sep 17 00:00:00 2001
From: Simon Eskildsen
Date: Tue, 28 Jun 2022 19:43:09 +0000
Subject: [PATCH] benchmark: add suite

---
 .gitignore                   |   1 +
 README.md                    |   7 ++
 data_diff/diff_tables.py     |  13 ++-
 tests/common.py              |   7 ++
 tests/test_database_types.py | 217 +++++++++++++++++++++++++++--------
 5 files changed, 199 insertions(+), 46 deletions(-)

diff --git a/.gitignore b/.gitignore
index 943fa569..802b32f6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -133,6 +133,7 @@ ml-25m*
 ratings*.csv
 drive
 mysqltuner.pl
+benchmark_*.jsonl
 
 # Mac
 .DS_Store
diff --git a/README.md b/README.md
index 8ddca7d8..b41464a3 100644
--- a/README.md
+++ b/README.md
@@ -457,6 +457,13 @@ $ poetry run preql -f dev/prepare_db.pql bigquery:///
 poetry run python3 -m data_diff postgresql://postgres:Password1@localhost/postgres rating postgresql://postgres:Password1@localhost/postgres rating_del1 --verbose
 ```
 
+**6. Run benchmarks (optional)**
+
+```shell-session
+$ dev/benchmark.sh
+```
+
+
 # License
 
 [MIT License](https://github.com/datafold/data-diff/blob/master/LICENSE)
diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py
index 143714f2..805f92d0 100644
--- a/data_diff/diff_tables.py
+++ b/data_diff/diff_tables.py
@@ -3,6 +3,7 @@
 
 from abc import ABC, abstractmethod
 import time
+import os
 from operator import attrgetter, methodcaller
 from collections import defaultdict
 from typing import List, Tuple, Iterator, Optional, Type
@@ -28,7 +29,7 @@
 logger = logging.getLogger("diff_tables")
 
 RECOMMENDED_CHECKSUM_DURATION = 10
 
-
+BENCHMARK = os.environ.get("BENCHMARK", False)
 DEFAULT_BISECTION_THRESHOLD = 1024 * 16
 DEFAULT_BISECTION_FACTOR = 32
@@ -409,6 +410,16 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun
                 f"size: {table2.max_key-table1.min_key}"
             )
 
+        # When benchmarking, we want the ability to skip checksumming. This
+        # allows us to benchmark the performance of downloading all rows for
+        # comparison. By default, data-diff will checksum the segment first
+        # (when it's below the bisection threshold) and _then_ download it.
+ if BENCHMARK: + max_rows_from_keys = max(table1.max_key - table1.min_key, table2.max_key - table2.min_key) + if max_rows_from_keys < self.bisection_threshold: + yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max_rows_from_keys) + return + (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) if count1 == 0 and count2 == 0: diff --git a/tests/common.py b/tests/common.py index d042146b..2e6d7443 100644 --- a/tests/common.py +++ b/tests/common.py @@ -3,6 +3,7 @@ from data_diff import databases as db import logging +import subprocess TEST_MYSQL_CONN_STRING: str = "mysql://mysql:Password1@localhost/mysql" TEST_POSTGRESQL_CONN_STRING: str = None @@ -14,6 +15,12 @@ DEFAULT_N_SAMPLES = 50 N_SAMPLES = int(os.environ.get("N_SAMPLES", DEFAULT_N_SAMPLES)) +BENCHMARK = os.environ.get("BENCHMARK", False) + +def get_git_revision_short_hash() -> str: + return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip() + +GIT_REVISION=get_git_revision_short_hash() level = logging.ERROR if os.environ.get("LOG_LEVEL", False): diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 29c814e2..ea5d0a28 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -1,16 +1,19 @@ from contextlib import suppress import unittest import time +import json import re +import rich.progress import math import uuid from datetime import datetime, timedelta +import logging from decimal import Decimal from parameterized import parameterized from data_diff import databases as db -from data_diff.diff_tables import TableDiffer, TableSegment -from .common import CONN_STRINGS, N_SAMPLES +from data_diff.diff_tables import TableDiffer, TableSegment, DEFAULT_BISECTION_THRESHOLD +from .common import CONN_STRINGS, N_SAMPLES, BENCHMARK, GIT_REVISION CONNS = {k: db.connect_to_uri(v, 1) for k, v in CONN_STRINGS.items()} @@ -172,7 +175,7 @@ def __iter__(self): TYPE_SAMPLES = { "int": IntFaker(N_SAMPLES), - "datetime_no_timezone": DateTimeFaker(N_SAMPLES), + "datetime": DateTimeFaker(N_SAMPLES), "float": FloatFaker(N_SAMPLES), "uuid": UUID_Faker(N_SAMPLES), } @@ -186,7 +189,7 @@ def __iter__(self): "bigint", # 8 bytes ], # https://www.postgresql.org/docs/current/datatype-datetime.html - "datetime_no_timezone": [ + "datetime": [ "timestamp(6) without time zone", "timestamp(3) without time zone", "timestamp(0) without time zone", @@ -214,7 +217,7 @@ def __iter__(self): "bigint", # 8 bytes ], # https://dev.mysql.com/doc/refman/8.0/en/datetime.html - "datetime_no_timezone": [ + "datetime": [ "timestamp(6)", "timestamp(3)", "timestamp(0)", @@ -235,7 +238,7 @@ def __iter__(self): }, db.BigQuery: { "int": ["int"], - "datetime_no_timezone": [ + "datetime": [ "timestamp", # "datetime", ], @@ -260,7 +263,7 @@ def __iter__(self): # "byteint" ], # https://docs.snowflake.com/en/sql-reference/data-types-datetime.html - "datetime_no_timezone": [ + "datetime": [ "timestamp(0)", "timestamp(3)", "timestamp(6)", @@ -280,7 +283,7 @@ def __iter__(self): "int": [ "int", ], - "datetime_no_timezone": [ + "datetime": [ "TIMESTAMP", ], # https://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html#r_Numeric_types201-floating-point-types @@ -299,7 +302,7 @@ def __iter__(self): "int": [ "int", ], - "datetime_no_timezone": [ + "datetime": [ "timestamp with local time zone", "timestamp(6) with local time zone", "timestamp(9) with local time zone", @@ -325,7 +328,7 @@ def __iter__(self): "int", # 4 bytes 
"bigint", # 8 bytes ], - "datetime_no_timezone": [ + "datetime": [ "timestamp", "timestamp with time zone", ], @@ -372,6 +375,7 @@ def sanitize(name): name = name.replace(r"with time zone", "y_tz") name = name.replace(r"with local time zone", "y_tz") name = name.replace(r"timestamp", "ts") + name = name.replace(r"double precision", "double") return parameterized.to_safe_name(name) @@ -402,35 +406,89 @@ def expand_params(testcase_func, param_num, param): return name -def _insert_to_table(conn, table, values): - insertion_query = f"INSERT INTO {table} (id, col) " +def _insert_to_table(conn, table, values, type): + current_n_rows = conn.query(f"SELECT COUNT(*) FROM {table}", int) + if current_n_rows == N_SAMPLES: + assert BENCHMARK, "Table should've been deleted, or we should be in BENCHMARK mode" + return + elif current_n_rows > 0: + _drop_table_if_exists(conn, table) + _create_table_with_indexes(conn, table, type) + if BENCHMARK and N_SAMPLES > 10_000: + description = f"{conn.name}: {table}" + values = rich.progress.track(values, total=N_SAMPLES, description=description) + + default_insertion_query = f"INSERT INTO {table} (id, col) VALUES " if isinstance(conn, db.Oracle): - selects = [] - for j, sample in values: - if isinstance(sample, (float, Decimal, int)): - value = str(sample) - elif isinstance(sample, datetime): - value = f"timestamp '{sample}'" - else: - value = f"'{sample}'" + default_insertion_query = f"INSERT INTO {table} (id, col)" + + insertion_query = default_insertion_query + selects = [] + for j, sample in values: + if isinstance(sample, (float, Decimal, int)): + value = str(sample) + elif isinstance(sample, datetime) and isinstance(conn, (db.Presto, db.Oracle)): + value = f"timestamp '{sample}'" + else: + value = f"'{sample}'" + + if isinstance(conn, db.Oracle): selects.append(f"SELECT {j}, {value} FROM dual") - insertion_query += " UNION ALL ".join(selects) - else: - insertion_query += " VALUES " - for j, sample in values: - if isinstance(sample, (float, Decimal, int)): - value = str(sample) - elif isinstance(sample, datetime) and isinstance(conn, db.Presto): - value = f"timestamp '{sample}'" - else: - value = f"'{sample}'" + else: insertion_query += f"({j}, {value})," - insertion_query = insertion_query[0:-1] + # Some databases want small batch sizes... + # Need to also insert on the last row, might not divide cleanly! + if j % 8000 == 0 or j == N_SAMPLES: + if isinstance(conn, db.Oracle): + insertion_query += " UNION ALL ".join(selects) + conn.query(insertion_query, None) + selects = [] + else: + conn.query(insertion_query[0:-1], None) + insertion_query = default_insertion_query + + if not isinstance(conn, db.BigQuery): + conn.query("COMMIT", None) + + +def _create_indexes(conn, table): + # It is unfortunate that Presto doesn't support creating indexes... + # Technically we could create it in the backing Postgres behind the scenes. 
+ if isinstance(conn, (db.Snowflake, db.Redshift, db.Presto, db.BigQuery)): + return + + try: + if_not_exists = "IF NOT EXISTS" if not isinstance(conn, (db.MySQL, db.Oracle)) else "" + conn.query( + f"CREATE INDEX {if_not_exists} idx_{table[1:-1]}_id_col ON {table} (id, col)", + None, + ) + conn.query( + f"CREATE INDEX {if_not_exists} idx_{table[1:-1]}_id ON {table} (id)", + None, + ) + except Exception as err: + if "Duplicate key name" in str(err): # mysql + pass + elif "such column list already indexed" in str(err): # oracle + pass + elif "name is already used" in str(err): # oracle + pass + else: + raise (err) - conn.query(insertion_query, None) +def _create_table_with_indexes(conn, table, type): + if isinstance(conn, db.Oracle): + already_exists = conn.query(f"SELECT COUNT(*) from tab where tname='{table.upper()}'", int) > 0 + if not already_exists: + conn.query(f"CREATE TABLE {table}(id int, col {type})", None) + else: + conn.query(f"CREATE TABLE IF NOT EXISTS {table}(id int, col {type})", None) + + _create_indexes(conn, table) if not isinstance(conn, db.BigQuery): conn.query("COMMIT", None) @@ -447,6 +505,8 @@ def _drop_table_if_exists(conn, table): class TestDiffCrossDatabaseTables(unittest.TestCase): + maxDiff = 10000 + @parameterized.expand(type_pairs, name_func=expand_params) def test_types(self, source_db, target_db, source_type, target_type, type_category): start = time.time() @@ -466,39 +526,106 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego src_table = src_conn.quote(".".join(src_table_path)) dst_table = dst_conn.quote(".".join(dst_table_path)) + start = time.time() _drop_table_if_exists(src_conn, src_table) - src_conn.query(f"CREATE TABLE {src_table}(id int, col {source_type})", None) - _insert_to_table(src_conn, src_table, enumerate(sample_values, 1)) + _create_table_with_indexes(src_conn, src_table, source_type) + _insert_to_table(src_conn, src_table, enumerate(sample_values, 1), source_type) + insertion_source_duration = time.time() - start values_in_source = PaginatedTable(src_table, src_conn) if source_db is db.Presto: if source_type.startswith("decimal"): - values_in_source = [(a, Decimal(b)) for a, b in values_in_source] + values_in_source = ((a, Decimal(b)) for a, b in values_in_source) elif source_type.startswith("timestamp"): - values_in_source = [(a, datetime.fromisoformat(b.rstrip(" UTC"))) for a, b in values_in_source] + values_in_source = ((a, datetime.fromisoformat(b.rstrip(" UTC"))) for a, b in values_in_source) + start = time.time() _drop_table_if_exists(dst_conn, dst_table) - dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type})", None) - _insert_to_table(dst_conn, dst_table, values_in_source) + _create_table_with_indexes(dst_conn, dst_table, target_type) + _insert_to_table(dst_conn, dst_table, values_in_source, target_type) + insertion_target_duration = time.time() - start self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), case_sensitive=False) self.table2 = TableSegment(self.dst_conn, dst_table_path, "id", None, ("col",), case_sensitive=False) - self.assertEqual(len(sample_values), self.table.count()) - self.assertEqual(len(sample_values), self.table2.count()) + start = time.time() + self.assertEqual(N_SAMPLES, self.table.count()) + count_source_duration = time.time() - start - differ = TableDiffer(bisection_threshold=3, bisection_factor=2) # ensure we actually checksum + start = time.time() + self.assertEqual(N_SAMPLES, self.table2.count()) + count_target_duration = 
time.time() - start
+
+        # When testing with the default DEFAULT_N_SAMPLES, we configure these
+        # to their lowest possible values so that checksumming is actually
+        # exercised. When benchmarking, we derive a configuration dynamically,
+        # targeting segments of roughly 250k rows each.
+        ch_factor = min(max(int(N_SAMPLES / 250_000), 2), 128) if BENCHMARK else 2
+        ch_threshold = min(DEFAULT_BISECTION_THRESHOLD, int(N_SAMPLES / ch_factor)) if BENCHMARK else 3
+        ch_threads = 1
+        differ = TableDiffer(
+            bisection_threshold=ch_threshold,
+            bisection_factor=ch_factor,
+            max_threadpool_size=ch_threads,
+        )
+        start = time.time()
         diff = list(differ.diff_tables(self.table, self.table2))
+        checksum_duration = time.time() - start
         expected = []
         self.assertEqual(expected, diff)
         self.assertEqual(0, differ.stats.get("rows_downloaded", 0))
 
-        # Ensure that Python agrees with the checksum!
-        differ = TableDiffer(bisection_threshold=1000000000)
+        # This section downloads all rows to ensure that Python's comparison of
+        # the values agrees with the database's.
+        #
+        # For benchmarking, to keep it fair, we split into segments of a
+        # reasonable number of rows each. These will then be downloaded in
+        # parallel, using the existing implementation.
+        dl_factor = max(int(N_SAMPLES / 100_000), 2) if BENCHMARK else 2
+        dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else N_SAMPLES + 1
+        dl_threads = 1
+        differ = TableDiffer(
+            bisection_threshold=dl_threshold, bisection_factor=dl_factor, max_threadpool_size=dl_threads
+        )
+        start = time.time()
         diff = list(differ.diff_tables(self.table, self.table2))
+        download_duration = time.time() - start
         expected = []
         self.assertEqual(expected, diff)
         self.assertEqual(len(sample_values), differ.stats.get("rows_downloaded", 0))
 
-        duration = time.time() - start
-        # print(f"source_db={source_db.__name__} target_db={target_db.__name__} source_type={source_type} target_type={target_type} duration={round(duration * 1000, 2)}ms")
+        result = {
+            "test": self._testMethodName,
+            "source_db": source_db.__name__,
+            "target_db": target_db.__name__,
+            "date": str(datetime.today()),
+            "git_revision": GIT_REVISION,
+            "rows": N_SAMPLES,
+            "rows_human": number_to_human(N_SAMPLES),
+            "src_table": src_table[1:-1],  # remove quotes
+            "target_table": dst_table[1:-1],
+            "source_type": source_type,
+            "target_type": target_type,
+            "insertion_source_sec": round(insertion_source_duration, 3),
+            "insertion_target_sec": round(insertion_target_duration, 3),
+            "count_source_sec": round(count_source_duration, 3),
+            "count_target_sec": round(count_target_duration, 3),
+            "checksum_sec": round(checksum_duration, 3),
+            "download_sec": round(download_duration, 3),
+            "download_bisection_factor": dl_factor,
+            "download_bisection_threshold": dl_threshold,
+            "download_threads": dl_threads,
+            "checksum_bisection_factor": ch_factor,
+            "checksum_bisection_threshold": ch_threshold,
+            "checksum_threads": ch_threads,
+        }
+
+        if BENCHMARK:
+            print(json.dumps(result, indent=2))
+            file_name = f"benchmark_{GIT_REVISION}.jsonl"
+            with open(file_name, "a") as file:
+                file.write(json.dumps(result) + "\n")
+                file.flush()
+            print(f"Written to {file_name}")
+        else:
+            logging.debug(json.dumps(result, indent=2))
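
Each benchmark run appends one JSON object per test case to `benchmark_<git revision>.jsonl`, using the field names from the `result` dict above. The patch ships no reader for these files (and `dev/benchmark.sh` is not shown here), so the snippet below is only a minimal sketch of how the output could be aggregated; the `summarize` helper and the grouping by database pair are assumptions, not part of the suite.

```python
# Hypothetical reader for the benchmark_*.jsonl files produced above (not part
# of this patch). It averages the checksum and download timings per database
# pair and row count, so two revisions can be compared side by side.
import glob
import json
from collections import defaultdict


def summarize(pattern: str = "benchmark_*.jsonl") -> None:
    groups = defaultdict(list)  # (source_db, target_db, rows) -> [records]
    for path in glob.glob(pattern):
        with open(path) as f:
            for line in f:
                if not line.strip():
                    continue
                record = json.loads(line)
                key = (record["source_db"], record["target_db"], record["rows"])
                groups[key].append(record)

    for (source_db, target_db, rows), records in sorted(groups.items()):
        checksum = sum(r["checksum_sec"] for r in records) / len(records)
        download = sum(r["download_sec"] for r in records) / len(records)
        print(
            f"{source_db} -> {target_db} ({rows} rows): "
            f"checksum {checksum:.2f}s avg, download {download:.2f}s avg"
        )


if __name__ == "__main__":
    summarize()
```

Pointing the pattern at the `benchmark_<rev>.jsonl` file of each revision separately gives a quick before/after view of a change.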