Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Commit

Permalink
adding pydantic
Browse files Browse the repository at this point in the history
Signed-off-by: Sarad Mohanan <[email protected]>
  • Loading branch information
sar009 committed Feb 8, 2024
1 parent bd075c5 commit 965ead0
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 125 deletions.
28 changes: 1 addition & 27 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import click
import rich
from pydantic_core._pydantic_core import ValidationError
from rich.logging import RichHandler

from data_diff import Database, DbPath
Expand Down Expand Up @@ -243,7 +244,6 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
"'serial' guarantees a single-threaded execution of the algorithm (useful for debugging)."
),
metavar="COUNT",
type=int,
)
@click.option(
"-w",
Expand Down Expand Up @@ -494,38 +494,12 @@ def _get_expanded_columns(
return expanded_columns


def _set_threads(cli_options: CliOptions) -> None:
cli_options.threaded = True
message = "Error: threads must be of type int, or value must be 'serial'."
if cli_options.threads is None:
cli_options.threads = 1

elif isinstance(cli_options.threads, str):
if cli_options.threads.lower() != "serial":
logging.error(message)
raise ValueError(message)

assert not (cli_options.threads1 or cli_options.threads2)
cli_options.threaded = False
cli_options.threads = 1

elif not isinstance(cli_options.threads, int):
logging.error(message)
raise ValueError(message)

elif cli_options.threads < 1:
message = "Error: threads must be >= 1"
logging.error(message)
raise ValueError(message)


def _data_diff(cli_options: CliOptions) -> None:
if cli_options.limit and cli_options.stats:
logging.error("Cannot specify a limit when using the -s/--stats switch")
return

key_columns = cli_options.key_columns or ("id",)
_set_threads(cli_options)
start = time.monotonic()

if cli_options.database1 is None or cli_options.database2 is None:
Expand Down
32 changes: 21 additions & 11 deletions data_diff/cli_options.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from dataclasses import dataclass
from typing import Optional, Literal, Dict, Union, Tuple

from pydantic import BaseModel, PositiveInt, model_validator

@dataclass
class CliOptions:
bisection_factor: int
bisection_threshold: int
table_write_limit: int

class CliOptions(BaseModel):
bisection_factor: PositiveInt
bisection_threshold: PositiveInt
table_write_limit: PositiveInt
database1: Union[str, Dict, None] = None
table1: Optional[str] = None
database2: Union[str, Dict, None] = None
table2: Optional[str] = None
key_columns: Tuple[str] = ()
update_column: Optional[str] = None
columns: Tuple[str] = ()
limit: Optional[int] = None
columns: Tuple[str, ...] = ()
limit: Optional[PositiveInt] = None
materialize_to_table: Optional[str] = None
min_age: Optional[str] = None
max_age: Optional[str] = None
Expand All @@ -29,9 +29,9 @@ class CliOptions:
assume_unique_key: bool = False
sample_exclusive_rows: bool = False
materialize_all_rows: bool = False
threads: Union[None, int, Literal["serial"]] = None
threads1: Optional[int] = None
threads2: Optional[int] = None
threads: Union[PositiveInt, Literal["serial"]] = 1
threads1: Optional[PositiveInt] = None
threads2: Optional[PositiveInt] = None
threaded: bool = False
where: Optional[str] = None
algorithm: Literal["auto", "joindiff", "hashdiff"] = None
Expand All @@ -46,3 +46,13 @@ class CliOptions:
prod_database: Optional[str] = None
prod_schema: Optional[str] = None
__conf__: Optional[Dict] = None

@model_validator(mode="after")
def check_threads(self) -> "CliOptions":
self.threaded = True
if self.threads == "serial":
if self.threads1 or self.threads2:
raise ValueError("threads1 and threads2 can not be set when threads is set to serial.")
self.threads = 1
self.threaded = False
return self
4 changes: 2 additions & 2 deletions data_diff/cloud/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
from typing import List, Optional, Union, overload

import pydantic
from pydantic import BaseModel
import rich
from rich.table import Table
from rich.prompt import Confirm, Prompt, FloatPrompt, IntPrompt, InvalidResponse
Expand All @@ -21,7 +21,7 @@
UNKNOWN_VALUE = "unknown_value"


class TDataSourceTestStage(pydantic.BaseModel):
class TDataSourceTestStage(BaseModel):
name: str
status: TestDataSourceStatus
description: str = ""
Expand Down
32 changes: 16 additions & 16 deletions data_diff/cloud/datafold_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import base64
import enum
import time
from typing import Any, Dict, List, Optional, Type, Tuple
from typing import Any, Dict, List, Optional, Tuple

import attrs
import pydantic
import requests
from pydantic import BaseModel
from typing_extensions import Self

from data_diff.errors import DataDiffCloudDiffFailed, DataDiffCloudDiffTimedOut, DataDiffDatasourceIdNotFoundError
Expand All @@ -21,7 +21,7 @@ class TestDataSourceStatus(str, enum.Enum):
UNKNOWN = "unknown"


class TCloudApiDataSourceSchema(pydantic.BaseModel):
class TCloudApiDataSourceSchema(BaseModel):
title: str
properties: Dict[str, Dict[str, Any]]
required: List[str]
Expand All @@ -48,13 +48,13 @@ def from_orm(cls, obj: Any) -> Self:
)


class TCloudApiDataSourceConfigSchema(pydantic.BaseModel):
class TCloudApiDataSourceConfigSchema(BaseModel):
name: str
db_type: str
config_schema: TCloudApiDataSourceSchema


class TCloudApiDataSource(pydantic.BaseModel):
class TCloudApiDataSource(BaseModel):
id: Optional[int] = None
name: str
type: str
Expand Down Expand Up @@ -85,7 +85,7 @@ class TCloudApiDataSource(pydantic.BaseModel):
secret_id: Optional[int] = None


class TDsConfig(pydantic.BaseModel):
class TDsConfig(BaseModel):
name: str
type: str
temp_schema: str
Expand All @@ -95,7 +95,7 @@ class TDsConfig(pydantic.BaseModel):
disable_profiling: bool = True


class TCloudApiDataDiff(pydantic.BaseModel):
class TCloudApiDataDiff(BaseModel):
data_source1_id: int
data_source2_id: int
table1: List[str]
Expand All @@ -107,26 +107,26 @@ class TCloudApiDataDiff(pydantic.BaseModel):
exclude_columns: Optional[List[str]]


class TCloudApiOrgMeta(pydantic.BaseModel):
class TCloudApiOrgMeta(BaseModel):
org_id: int
org_name: str
user_id: int


class TSummaryResultPrimaryKeyStats(pydantic.BaseModel):
class TSummaryResultPrimaryKeyStats(BaseModel):
total_rows: Tuple[int, int]
nulls: Tuple[int, int]
dupes: Tuple[int, int]
exclusives: Tuple[int, int]
distincts: Tuple[int, int]


class TSummaryResultColumnDiffStats(pydantic.BaseModel):
class TSummaryResultColumnDiffStats(BaseModel):
column_name: str
match: float


class TSummaryResultValueStats(pydantic.BaseModel):
class TSummaryResultValueStats(BaseModel):
total_rows: int
rows_with_differences: int
total_values: int
Expand All @@ -135,7 +135,7 @@ class TSummaryResultValueStats(pydantic.BaseModel):
columns_diff_stats: List[TSummaryResultColumnDiffStats]


class TSummaryResultSchemaStats(pydantic.BaseModel):
class TSummaryResultSchemaStats(BaseModel):
columns_mismatched: Tuple[int, int]
column_type_mismatches: int
column_reorders: int
Expand All @@ -144,11 +144,11 @@ class TSummaryResultSchemaStats(pydantic.BaseModel):
exclusive_columns: Tuple[List[str], List[str]]


class TSummaryResultDependencyDetails(pydantic.BaseModel):
class TSummaryResultDependencyDetails(BaseModel):
deps: Dict[str, List[Dict]]


class TCloudApiDataDiffSummaryResult(pydantic.BaseModel):
class TCloudApiDataDiffSummaryResult(BaseModel):
status: str
pks: Optional[TSummaryResultPrimaryKeyStats]
values: Optional[TSummaryResultValueStats]
Expand All @@ -170,13 +170,13 @@ def from_orm(cls, obj: Any) -> Self:
)


class TCloudDataSourceTestResult(pydantic.BaseModel):
class TCloudDataSourceTestResult(BaseModel):
status: TestDataSourceStatus
message: str
outcome: str


class TCloudApiDataSourceTestResult(pydantic.BaseModel):
class TCloudApiDataSourceTestResult(BaseModel):
name: str
status: str
result: Optional[TCloudDataSourceTestResult]
Expand Down
20 changes: 10 additions & 10 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
from contextlib import nullcontext
import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import nullcontext
from typing import List, Optional, Dict, Tuple, Union

import keyring
import pydantic
import rich
from rich.prompt import Prompt
from pydantic import BaseModel
from rich.markdown import Markdown
from concurrent.futures import ThreadPoolExecutor, as_completed
from rich.prompt import Prompt

from data_diff import connect_to_table, diff_tables, Algorithm
from data_diff.cli_options import CliOptions
from data_diff.cloud import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta
from data_diff.dbt_parser import DbtParser, TDatadiffConfig
from data_diff.diff_tables import DiffResultWrapper
from data_diff.errors import (
DataDiffCustomSchemaNoConfigError,
DataDiffDbtProjectVarsNotFoundError,
DataDiffNoAPIKeyError,
DataDiffNoDatasourceIdError,
)

from data_diff import connect_to_table, diff_tables, Algorithm
from data_diff.cloud import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta
from data_diff.dbt_parser import DbtParser, TDatadiffConfig
from data_diff.diff_tables import DiffResultWrapper
from data_diff.format import jsonify, jsonify_error
from data_diff.tracking import (
bool_ask_for_email,
Expand Down Expand Up @@ -56,7 +56,7 @@
DATAFOLD_INSTRUCTIONS_URL = "https://docs.datafold.com/development_testing/datafold_cloud"


class TDiffVars(pydantic.BaseModel):
class TDiffVars(BaseModel):
dev_path: List[str]
prod_path: List[str]
primary_keys: List[str]
Expand Down
4 changes: 2 additions & 2 deletions data_diff/dbt_config_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class ManifestJsonConfig(BaseModel):
class Metadata(BaseModel):
dbt_version: str = Field(..., regex=r"^\d+\.\d+\.\d+([a-zA-Z0-9]+)?$")
dbt_version: str = Field(..., pattern=r"^\d+\.\d+\.\d+([a-zA-Z0-9]+)?$")
project_id: Optional[str]
user_id: Optional[str]

Expand Down Expand Up @@ -46,7 +46,7 @@ class DependsOn(BaseModel):

class RunResultsJsonConfig(BaseModel):
class Metadata(BaseModel):
dbt_version: str = Field(..., regex=r"^\d+\.\d+\.\d+([a-zA-Z0-9]+)?$")
dbt_version: str = Field(..., pattern=r"^\d+\.\d+\.\d+([a-zA-Z0-9]+)?$")

class Results(BaseModel):
class Status(Enum):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ classifiers = [
]
packages = [{ include = "data_diff" }]
[tool.poetry.dependencies]
pydantic = "1.10.12"
pydantic = "^2.6.0"
python = "^3.8.0"
dsnparse = "<0.2.0"
click = "^8.1"
Expand Down
4 changes: 2 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,5 +187,5 @@ def ansi_stdout_cleanup(ansi_input) -> str:
return re.sub(r"\x1B\[[0-?]*[ -/]*[@-~]", "", ansi_input)


def get_cli_options() -> CliOptions:
return CliOptions(bisection_factor=2, bisection_threshold=3, table_write_limit=1)
def get_cli_options(**kwargs) -> CliOptions:
return CliOptions(bisection_factor=2, bisection_threshold=3, table_write_limit=1, **kwargs)
Loading

0 comments on commit 965ead0

Please sign in to comment.