diff --git a/data_diff/__main__.py b/data_diff/__main__.py index b5038b08..421f1101 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -10,6 +10,7 @@ import click import rich +from pydantic_core._pydantic_core import ValidationError from rich.logging import RichHandler from data_diff import Database, DbPath @@ -243,7 +244,6 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - "'serial' guarantees a single-threaded execution of the algorithm (useful for debugging)." ), metavar="COUNT", - type=int, ) @click.option( "-w", @@ -494,38 +494,12 @@ def _get_expanded_columns( return expanded_columns -def _set_threads(cli_options: CliOptions) -> None: - cli_options.threaded = True - message = "Error: threads must be of type int, or value must be 'serial'." - if cli_options.threads is None: - cli_options.threads = 1 - - elif isinstance(cli_options.threads, str): - if cli_options.threads.lower() != "serial": - logging.error(message) - raise ValueError(message) - - assert not (cli_options.threads1 or cli_options.threads2) - cli_options.threaded = False - cli_options.threads = 1 - - elif not isinstance(cli_options.threads, int): - logging.error(message) - raise ValueError(message) - - elif cli_options.threads < 1: - message = "Error: threads must be >= 1" - logging.error(message) - raise ValueError(message) - - def _data_diff(cli_options: CliOptions) -> None: if cli_options.limit and cli_options.stats: logging.error("Cannot specify a limit when using the -s/--stats switch") return key_columns = cli_options.key_columns or ("id",) - _set_threads(cli_options) start = time.monotonic() if cli_options.database1 is None or cli_options.database2 is None: diff --git a/data_diff/cli_options.py b/data_diff/cli_options.py index 361ead75..01017500 100644 --- a/data_diff/cli_options.py +++ b/data_diff/cli_options.py @@ -1,20 +1,20 @@ -from dataclasses import dataclass from typing import Optional, Literal, Dict, Union, Tuple +from pydantic import 
@model_validator(mode="after")
def check_threads(self) -> "CliOptions":
    """Normalise the ``threads`` setting after field validation.

    ``threaded`` is switched on unconditionally first. When ``threads`` is
    the literal ``"serial"`` it is rewritten to ``1`` and ``threaded`` is
    switched off; in that mode neither ``threads1`` nor ``threads2`` may
    be set.

    Returns:
        This same instance, with ``threads`` / ``threaded`` possibly updated.

    Raises:
        ValueError: if ``threads`` is ``"serial"`` while ``threads1`` or
            ``threads2`` is truthy.
    """
    serial_requested = self.threads == "serial"
    self.threaded = True
    if not serial_requested:
        return self

    # Serial mode is single-threaded by definition, so per-database thread
    # counts would contradict it.
    if any((self.threads1, self.threads2)):
        raise ValueError("threads1 and threads2 can not be set when threads is set to serial.")

    self.threads, self.threaded = 1, False
    return self
rich.table import Table from rich.prompt import Confirm, Prompt, FloatPrompt, IntPrompt, InvalidResponse @@ -21,7 +21,7 @@ UNKNOWN_VALUE = "unknown_value" -class TDataSourceTestStage(pydantic.BaseModel): +class TDataSourceTestStage(BaseModel): name: str status: TestDataSourceStatus description: str = "" diff --git a/data_diff/cloud/datafold_api.py b/data_diff/cloud/datafold_api.py index f6548abb..1158e38d 100644 --- a/data_diff/cloud/datafold_api.py +++ b/data_diff/cloud/datafold_api.py @@ -1,11 +1,11 @@ import base64 import enum import time -from typing import Any, Dict, List, Optional, Type, Tuple +from typing import Any, Dict, List, Optional, Tuple import attrs -import pydantic import requests +from pydantic import BaseModel from typing_extensions import Self from data_diff.errors import DataDiffCloudDiffFailed, DataDiffCloudDiffTimedOut, DataDiffDatasourceIdNotFoundError @@ -21,7 +21,7 @@ class TestDataSourceStatus(str, enum.Enum): UNKNOWN = "unknown" -class TCloudApiDataSourceSchema(pydantic.BaseModel): +class TCloudApiDataSourceSchema(BaseModel): title: str properties: Dict[str, Dict[str, Any]] required: List[str] @@ -48,13 +48,13 @@ def from_orm(cls, obj: Any) -> Self: ) -class TCloudApiDataSourceConfigSchema(pydantic.BaseModel): +class TCloudApiDataSourceConfigSchema(BaseModel): name: str db_type: str config_schema: TCloudApiDataSourceSchema -class TCloudApiDataSource(pydantic.BaseModel): +class TCloudApiDataSource(BaseModel): id: Optional[int] = None name: str type: str @@ -85,7 +85,7 @@ class TCloudApiDataSource(pydantic.BaseModel): secret_id: Optional[int] = None -class TDsConfig(pydantic.BaseModel): +class TDsConfig(BaseModel): name: str type: str temp_schema: str @@ -95,7 +95,7 @@ class TDsConfig(pydantic.BaseModel): disable_profiling: bool = True -class TCloudApiDataDiff(pydantic.BaseModel): +class TCloudApiDataDiff(BaseModel): data_source1_id: int data_source2_id: int table1: List[str] @@ -107,13 +107,13 @@ class 
TCloudApiDataDiff(pydantic.BaseModel): exclude_columns: Optional[List[str]] -class TCloudApiOrgMeta(pydantic.BaseModel): +class TCloudApiOrgMeta(BaseModel): org_id: int org_name: str user_id: int -class TSummaryResultPrimaryKeyStats(pydantic.BaseModel): +class TSummaryResultPrimaryKeyStats(BaseModel): total_rows: Tuple[int, int] nulls: Tuple[int, int] dupes: Tuple[int, int] @@ -121,12 +121,12 @@ class TSummaryResultPrimaryKeyStats(pydantic.BaseModel): distincts: Tuple[int, int] -class TSummaryResultColumnDiffStats(pydantic.BaseModel): +class TSummaryResultColumnDiffStats(BaseModel): column_name: str match: float -class TSummaryResultValueStats(pydantic.BaseModel): +class TSummaryResultValueStats(BaseModel): total_rows: int rows_with_differences: int total_values: int @@ -135,7 +135,7 @@ class TSummaryResultValueStats(pydantic.BaseModel): columns_diff_stats: List[TSummaryResultColumnDiffStats] -class TSummaryResultSchemaStats(pydantic.BaseModel): +class TSummaryResultSchemaStats(BaseModel): columns_mismatched: Tuple[int, int] column_type_mismatches: int column_reorders: int @@ -144,11 +144,11 @@ class TSummaryResultSchemaStats(pydantic.BaseModel): exclusive_columns: Tuple[List[str], List[str]] -class TSummaryResultDependencyDetails(pydantic.BaseModel): +class TSummaryResultDependencyDetails(BaseModel): deps: Dict[str, List[Dict]] -class TCloudApiDataDiffSummaryResult(pydantic.BaseModel): +class TCloudApiDataDiffSummaryResult(BaseModel): status: str pks: Optional[TSummaryResultPrimaryKeyStats] values: Optional[TSummaryResultValueStats] @@ -170,13 +170,13 @@ def from_orm(cls, obj: Any) -> Self: ) -class TCloudDataSourceTestResult(pydantic.BaseModel): +class TCloudDataSourceTestResult(BaseModel): status: TestDataSourceStatus message: str outcome: str -class TCloudApiDataSourceTestResult(pydantic.BaseModel): +class TCloudApiDataSourceTestResult(BaseModel): name: str status: str result: Optional[TCloudDataSourceTestResult] diff --git a/data_diff/dbt.py b/data_diff/dbt.py 
index 9fe19e74..dc486038 100644 --- a/data_diff/dbt.py +++ b/data_diff/dbt.py @@ -1,28 +1,28 @@ -from contextlib import nullcontext import json import os import re import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from contextlib import nullcontext from typing import List, Optional, Dict, Tuple, Union + import keyring -import pydantic import rich -from rich.prompt import Prompt +from pydantic import BaseModel from rich.markdown import Markdown -from concurrent.futures import ThreadPoolExecutor, as_completed +from rich.prompt import Prompt +from data_diff import connect_to_table, diff_tables, Algorithm from data_diff.cli_options import CliOptions +from data_diff.cloud import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta +from data_diff.dbt_parser import DbtParser, TDatadiffConfig +from data_diff.diff_tables import DiffResultWrapper from data_diff.errors import ( DataDiffCustomSchemaNoConfigError, DataDiffDbtProjectVarsNotFoundError, DataDiffNoAPIKeyError, DataDiffNoDatasourceIdError, ) - -from data_diff import connect_to_table, diff_tables, Algorithm -from data_diff.cloud import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta -from data_diff.dbt_parser import DbtParser, TDatadiffConfig -from data_diff.diff_tables import DiffResultWrapper from data_diff.format import jsonify, jsonify_error from data_diff.tracking import ( bool_ask_for_email, @@ -56,7 +56,7 @@ DATAFOLD_INSTRUCTIONS_URL = "https://docs.datafold.com/development_testing/datafold_cloud" -class TDiffVars(pydantic.BaseModel): +class TDiffVars(BaseModel): dev_path: List[str] prod_path: List[str] primary_keys: List[str] diff --git a/data_diff/dbt_config_validators.py b/data_diff/dbt_config_validators.py index e7c548a4..4606bcab 100644 --- a/data_diff/dbt_config_validators.py +++ b/data_diff/dbt_config_validators.py @@ -5,7 +5,7 @@ class ManifestJsonConfig(BaseModel): class Metadata(BaseModel): - dbt_version: str = Field(..., regex=r"^\d+\.\d+\.\d+([a-zA-Z0-9]+)?$") + 
def get_cli_options(**kwargs) -> CliOptions:
    """Build a ``CliOptions`` instance for use in tests.

    The three required fields receive small fixed defaults
    (``bisection_factor=2``, ``bisection_threshold=3``,
    ``table_write_limit=1``). Every keyword argument is forwarded to the
    ``CliOptions`` constructor and takes precedence over those defaults.

    Args:
        **kwargs: field values passed through to ``CliOptions``.

    Returns:
        The constructed ``CliOptions`` instance.
    """
    # Merge rather than writing the defaults as literal keywords next to
    # ``**kwargs``: that form raises ``TypeError`` (multiple values for a
    # keyword argument) as soon as a caller overrides one of the defaults.
    defaults = {"bisection_factor": 2, "bisection_threshold": 3, "table_write_limit": 1}
    return CliOptions(**{**defaults, **kwargs})
class TestGetThreads(unittest.TestCase):
    """Exercise the thread-related validation done when building ``CliOptions``."""

    def test__get_threads(self):
        # NOTE(review): the per-database fields are ``threads1``/``threads2``
        # (plural). A misspelled keyword such as ``thread1`` is presumably
        # dropped silently by the model instead of being rejected -- confirm
        # against the model configuration. Use the exact field names here so
        # the values under test really reach the model.

        # Nothing specified: ``threads`` keeps its default and threading is on.
        cli_options: CliOptions = get_cli_options(threads1=None, threads2=None)
        assert cli_options.threaded
        assert cli_options.threads == 1

        # Per-database counts do not affect the global ``threads`` value.
        cli_options = get_cli_options(threads1=2, threads2=3)
        assert cli_options.threaded
        assert cli_options.threads == 1
        assert cli_options.threads1 == 2
        assert cli_options.threads2 == 3

        # "serial" is normalised to a single, non-threaded worker.
        cli_options = get_cli_options(threads="serial", threads1=None, threads2=None)
        assert not cli_options.threaded
        assert cli_options.threads == 1

        # "serial" conflicts with any per-database thread count.
        with self.assertRaises(ValueError):
            get_cli_options(threads="serial", threads1=1, threads2=2)

        # The conflict must also be detected when only ``threads1`` is set.
        with self.assertRaises(ValueError):
            get_cli_options(threads="serial", threads1=1, threads2=None)

        # Any string other than "serial" that is not a number is rejected.
        with self.assertRaises(ValidationError):
            get_cli_options(threads="auto", threads1=None, threads2=None)

        # Numeric strings are coerced to integers.
        cli_options = get_cli_options(threads="4", threads1=None, threads2=None)
        assert cli_options.threaded
        assert cli_options.threads == 4

        cli_options = get_cli_options(threads=5, threads1=None, threads2=None)
        assert cli_options.threaded
        assert cli_options.threads == 5

        cli_options = get_cli_options(threads=6, threads1=7, threads2=8)
        assert cli_options.threaded
        assert cli_options.threads == 6
        assert cli_options.threads1 == 7
        assert cli_options.threads2 == 8

        # Thread counts must be strictly positive.
        with self.assertRaises(ValidationError):
            get_cli_options(threads=0, threads1=None, threads2=None)

        with self.assertRaises(ValidationError):
            get_cli_options(threads=-1, threads1=None, threads2=None)