Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 965ead0

Browse files
committed
adding pydantic
Signed-off-by: Sarad Mohanan <[email protected]>
1 parent bd075c5 commit 965ead0

File tree

9 files changed

+86
-125
lines changed

9 files changed

+86
-125
lines changed

data_diff/__main__.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import click
1212
import rich
13+
from pydantic_core._pydantic_core import ValidationError
1314
from rich.logging import RichHandler
1415

1516
from data_diff import Database, DbPath
@@ -243,7 +244,6 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
243244
"'serial' guarantees a single-threaded execution of the algorithm (useful for debugging)."
244245
),
245246
metavar="COUNT",
246-
type=int,
247247
)
248248
@click.option(
249249
"-w",
@@ -494,38 +494,12 @@ def _get_expanded_columns(
494494
return expanded_columns
495495

496496

497-
def _set_threads(cli_options: CliOptions) -> None:
498-
cli_options.threaded = True
499-
message = "Error: threads must be of type int, or value must be 'serial'."
500-
if cli_options.threads is None:
501-
cli_options.threads = 1
502-
503-
elif isinstance(cli_options.threads, str):
504-
if cli_options.threads.lower() != "serial":
505-
logging.error(message)
506-
raise ValueError(message)
507-
508-
assert not (cli_options.threads1 or cli_options.threads2)
509-
cli_options.threaded = False
510-
cli_options.threads = 1
511-
512-
elif not isinstance(cli_options.threads, int):
513-
logging.error(message)
514-
raise ValueError(message)
515-
516-
elif cli_options.threads < 1:
517-
message = "Error: threads must be >= 1"
518-
logging.error(message)
519-
raise ValueError(message)
520-
521-
522497
def _data_diff(cli_options: CliOptions) -> None:
523498
if cli_options.limit and cli_options.stats:
524499
logging.error("Cannot specify a limit when using the -s/--stats switch")
525500
return
526501

527502
key_columns = cli_options.key_columns or ("id",)
528-
_set_threads(cli_options)
529503
start = time.monotonic()
530504

531505
if cli_options.database1 is None or cli_options.database2 is None:

data_diff/cli_options.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
1-
from dataclasses import dataclass
21
from typing import Optional, Literal, Dict, Union, Tuple
32

3+
from pydantic import BaseModel, PositiveInt, model_validator
44

5-
@dataclass
6-
class CliOptions:
7-
bisection_factor: int
8-
bisection_threshold: int
9-
table_write_limit: int
5+
6+
class CliOptions(BaseModel):
7+
bisection_factor: PositiveInt
8+
bisection_threshold: PositiveInt
9+
table_write_limit: PositiveInt
1010
database1: Union[str, Dict, None] = None
1111
table1: Optional[str] = None
1212
database2: Union[str, Dict, None] = None
1313
table2: Optional[str] = None
1414
key_columns: Tuple[str] = ()
1515
update_column: Optional[str] = None
16-
columns: Tuple[str] = ()
17-
limit: Optional[int] = None
16+
columns: Tuple[str, ...] = ()
17+
limit: Optional[PositiveInt] = None
1818
materialize_to_table: Optional[str] = None
1919
min_age: Optional[str] = None
2020
max_age: Optional[str] = None
@@ -29,9 +29,9 @@ class CliOptions:
2929
assume_unique_key: bool = False
3030
sample_exclusive_rows: bool = False
3131
materialize_all_rows: bool = False
32-
threads: Union[None, int, Literal["serial"]] = None
33-
threads1: Optional[int] = None
34-
threads2: Optional[int] = None
32+
threads: Union[PositiveInt, Literal["serial"]] = 1
33+
threads1: Optional[PositiveInt] = None
34+
threads2: Optional[PositiveInt] = None
3535
threaded: bool = False
3636
where: Optional[str] = None
3737
algorithm: Literal["auto", "joindiff", "hashdiff"] = None
@@ -46,3 +46,13 @@ class CliOptions:
4646
prod_database: Optional[str] = None
4747
prod_schema: Optional[str] = None
4848
__conf__: Optional[Dict] = None
49+
50+
@model_validator(mode="after")
51+
def check_threads(self) -> "CliOptions":
52+
self.threaded = True
53+
if self.threads == "serial":
54+
if self.threads1 or self.threads2:
55+
raise ValueError("threads1 and threads2 can not be set when threads is set to serial.")
56+
self.threads = 1
57+
self.threaded = False
58+
return self

data_diff/cloud/data_source.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33
from typing import List, Optional, Union, overload
44

5-
import pydantic
5+
from pydantic import BaseModel
66
import rich
77
from rich.table import Table
88
from rich.prompt import Confirm, Prompt, FloatPrompt, IntPrompt, InvalidResponse
@@ -21,7 +21,7 @@
2121
UNKNOWN_VALUE = "unknown_value"
2222

2323

24-
class TDataSourceTestStage(pydantic.BaseModel):
24+
class TDataSourceTestStage(BaseModel):
2525
name: str
2626
status: TestDataSourceStatus
2727
description: str = ""

data_diff/cloud/datafold_api.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import base64
22
import enum
33
import time
4-
from typing import Any, Dict, List, Optional, Type, Tuple
4+
from typing import Any, Dict, List, Optional, Tuple
55

66
import attrs
7-
import pydantic
87
import requests
8+
from pydantic import BaseModel
99
from typing_extensions import Self
1010

1111
from data_diff.errors import DataDiffCloudDiffFailed, DataDiffCloudDiffTimedOut, DataDiffDatasourceIdNotFoundError
@@ -21,7 +21,7 @@ class TestDataSourceStatus(str, enum.Enum):
2121
UNKNOWN = "unknown"
2222

2323

24-
class TCloudApiDataSourceSchema(pydantic.BaseModel):
24+
class TCloudApiDataSourceSchema(BaseModel):
2525
title: str
2626
properties: Dict[str, Dict[str, Any]]
2727
required: List[str]
@@ -48,13 +48,13 @@ def from_orm(cls, obj: Any) -> Self:
4848
)
4949

5050

51-
class TCloudApiDataSourceConfigSchema(pydantic.BaseModel):
51+
class TCloudApiDataSourceConfigSchema(BaseModel):
5252
name: str
5353
db_type: str
5454
config_schema: TCloudApiDataSourceSchema
5555

5656

57-
class TCloudApiDataSource(pydantic.BaseModel):
57+
class TCloudApiDataSource(BaseModel):
5858
id: Optional[int] = None
5959
name: str
6060
type: str
@@ -85,7 +85,7 @@ class TCloudApiDataSource(pydantic.BaseModel):
8585
secret_id: Optional[int] = None
8686

8787

88-
class TDsConfig(pydantic.BaseModel):
88+
class TDsConfig(BaseModel):
8989
name: str
9090
type: str
9191
temp_schema: str
@@ -95,7 +95,7 @@ class TDsConfig(pydantic.BaseModel):
9595
disable_profiling: bool = True
9696

9797

98-
class TCloudApiDataDiff(pydantic.BaseModel):
98+
class TCloudApiDataDiff(BaseModel):
9999
data_source1_id: int
100100
data_source2_id: int
101101
table1: List[str]
@@ -107,26 +107,26 @@ class TCloudApiDataDiff(pydantic.BaseModel):
107107
exclude_columns: Optional[List[str]]
108108

109109

110-
class TCloudApiOrgMeta(pydantic.BaseModel):
110+
class TCloudApiOrgMeta(BaseModel):
111111
org_id: int
112112
org_name: str
113113
user_id: int
114114

115115

116-
class TSummaryResultPrimaryKeyStats(pydantic.BaseModel):
116+
class TSummaryResultPrimaryKeyStats(BaseModel):
117117
total_rows: Tuple[int, int]
118118
nulls: Tuple[int, int]
119119
dupes: Tuple[int, int]
120120
exclusives: Tuple[int, int]
121121
distincts: Tuple[int, int]
122122

123123

124-
class TSummaryResultColumnDiffStats(pydantic.BaseModel):
124+
class TSummaryResultColumnDiffStats(BaseModel):
125125
column_name: str
126126
match: float
127127

128128

129-
class TSummaryResultValueStats(pydantic.BaseModel):
129+
class TSummaryResultValueStats(BaseModel):
130130
total_rows: int
131131
rows_with_differences: int
132132
total_values: int
@@ -135,7 +135,7 @@ class TSummaryResultValueStats(pydantic.BaseModel):
135135
columns_diff_stats: List[TSummaryResultColumnDiffStats]
136136

137137

138-
class TSummaryResultSchemaStats(pydantic.BaseModel):
138+
class TSummaryResultSchemaStats(BaseModel):
139139
columns_mismatched: Tuple[int, int]
140140
column_type_mismatches: int
141141
column_reorders: int
@@ -144,11 +144,11 @@ class TSummaryResultSchemaStats(pydantic.BaseModel):
144144
exclusive_columns: Tuple[List[str], List[str]]
145145

146146

147-
class TSummaryResultDependencyDetails(pydantic.BaseModel):
147+
class TSummaryResultDependencyDetails(BaseModel):
148148
deps: Dict[str, List[Dict]]
149149

150150

151-
class TCloudApiDataDiffSummaryResult(pydantic.BaseModel):
151+
class TCloudApiDataDiffSummaryResult(BaseModel):
152152
status: str
153153
pks: Optional[TSummaryResultPrimaryKeyStats]
154154
values: Optional[TSummaryResultValueStats]
@@ -170,13 +170,13 @@ def from_orm(cls, obj: Any) -> Self:
170170
)
171171

172172

173-
class TCloudDataSourceTestResult(pydantic.BaseModel):
173+
class TCloudDataSourceTestResult(BaseModel):
174174
status: TestDataSourceStatus
175175
message: str
176176
outcome: str
177177

178178

179-
class TCloudApiDataSourceTestResult(pydantic.BaseModel):
179+
class TCloudApiDataSourceTestResult(BaseModel):
180180
name: str
181181
status: str
182182
result: Optional[TCloudDataSourceTestResult]

data_diff/dbt.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
1-
from contextlib import nullcontext
21
import json
32
import os
43
import re
54
import time
5+
from concurrent.futures import ThreadPoolExecutor, as_completed
6+
from contextlib import nullcontext
67
from typing import List, Optional, Dict, Tuple, Union
8+
79
import keyring
8-
import pydantic
910
import rich
10-
from rich.prompt import Prompt
11+
from pydantic import BaseModel
1112
from rich.markdown import Markdown
12-
from concurrent.futures import ThreadPoolExecutor, as_completed
13+
from rich.prompt import Prompt
1314

15+
from data_diff import connect_to_table, diff_tables, Algorithm
1416
from data_diff.cli_options import CliOptions
17+
from data_diff.cloud import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta
18+
from data_diff.dbt_parser import DbtParser, TDatadiffConfig
19+
from data_diff.diff_tables import DiffResultWrapper
1520
from data_diff.errors import (
1621
DataDiffCustomSchemaNoConfigError,
1722
DataDiffDbtProjectVarsNotFoundError,
1823
DataDiffNoAPIKeyError,
1924
DataDiffNoDatasourceIdError,
2025
)
21-
22-
from data_diff import connect_to_table, diff_tables, Algorithm
23-
from data_diff.cloud import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta
24-
from data_diff.dbt_parser import DbtParser, TDatadiffConfig
25-
from data_diff.diff_tables import DiffResultWrapper
2626
from data_diff.format import jsonify, jsonify_error
2727
from data_diff.tracking import (
2828
bool_ask_for_email,
@@ -56,7 +56,7 @@
5656
DATAFOLD_INSTRUCTIONS_URL = "https://docs.datafold.com/development_testing/datafold_cloud"
5757

5858

59-
class TDiffVars(pydantic.BaseModel):
59+
class TDiffVars(BaseModel):
6060
dev_path: List[str]
6161
prod_path: List[str]
6262
primary_keys: List[str]

data_diff/dbt_config_validators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
class ManifestJsonConfig(BaseModel):
77
class Metadata(BaseModel):
8-
dbt_version: str = Field(..., regex=r"^\d+\.\d+\.\d+([a-zA-Z0-9]+)?$")
8+
dbt_version: str = Field(..., pattern=r"^\d+\.\d+\.\d+([a-zA-Z0-9]+)?$")
99
project_id: Optional[str]
1010
user_id: Optional[str]
1111

@@ -46,7 +46,7 @@ class DependsOn(BaseModel):
4646

4747
class RunResultsJsonConfig(BaseModel):
4848
class Metadata(BaseModel):
49-
dbt_version: str = Field(..., regex=r"^\d+\.\d+\.\d+([a-zA-Z0-9]+)?$")
49+
dbt_version: str = Field(..., pattern=r"^\d+\.\d+\.\d+([a-zA-Z0-9]+)?$")
5050

5151
class Results(BaseModel):
5252
class Status(Enum):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ classifiers = [
2121
]
2222
packages = [{ include = "data_diff" }]
2323
[tool.poetry.dependencies]
24-
pydantic = "1.10.12"
24+
pydantic = "^2.6.0"
2525
python = "^3.8.0"
2626
dsnparse = "<0.2.0"
2727
click = "^8.1"

tests/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,5 +187,5 @@ def ansi_stdout_cleanup(ansi_input) -> str:
187187
return re.sub(r"\x1B\[[0-?]*[ -/]*[@-~]", "", ansi_input)
188188

189189

190-
def get_cli_options() -> CliOptions:
191-
return CliOptions(bisection_factor=2, bisection_threshold=3, table_write_limit=1)
190+
def get_cli_options(**kwargs) -> CliOptions:
191+
return CliOptions(bisection_factor=2, bisection_threshold=3, table_write_limit=1, **kwargs)

0 commit comments

Comments
 (0)