Chore!: remove pydantic v1 validator arg helpers
georgesittas committed Jan 17, 2025
1 parent 7fd6b7d commit fc35f96
Showing 18 changed files with 360 additions and 438 deletions.
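The commit removes the `model_validator_v1_args` and `field_validator_v1_args` helpers, which adapted pydantic v1-style validator signatures (`cls, values`) to pydantic v2, and rewrites their call sites as native v2 validators. A minimal sketch of the resulting pattern, using a simplified stand-in model rather than the real sqlmesh classes:

import typing as t

from pydantic import BaseModel, ValidationError, model_validator
from typing_extensions import Self  # stdlib typing.Self on Python 3.11+


class StandaloneAudit(BaseModel):
    name: str
    blocking: bool = False

    # pydantic v2 style: an "after" validator receives the constructed
    # instance and returns it (or raises), instead of a `values` dict.
    @model_validator(mode="after")
    def _node_root_validator(self) -> Self:
        if self.blocking:
            raise ValueError(f"Standalone audits cannot be blocking: '{self.name}'.")
        return self


try:
    StandaloneAudit(name="my_audit", blocking=True)
except ValidationError as e:
    print(e)  # the ValueError surfaces as a pydantic ValidationError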
6 changes: 6 additions & 0 deletions sqlmesh/core/_typing.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
import typing as t

from sqlglot import exp
@@ -9,3 +10,8 @@
SchemaName = t.Union[str, exp.Table]
SessionProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]]
CustomMaterializationProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]]

if sys.version_info >= (3, 11):
    from typing import Self as Self
else:
    from typing_extensions import Self as Self
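`typing.Self` was added to the standard library in Python 3.11, so older interpreters fall back to `typing_extensions`; the `import ... as Self` spelling marks the name as an explicit re-export, letting type checkers resolve it when other modules import it from `sqlmesh.core._typing` under `t.TYPE_CHECKING`. A hypothetical consumer (illustrative names only):

from __future__ import annotations

import typing as t

if t.TYPE_CHECKING:
    from sqlmesh.core._typing import Self


class SomeConfig:
    def finalize(self) -> Self:  # annotation is only resolved by the type checker
        return self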
18 changes: 6 additions & 12 deletions sqlmesh/core/audit/definition.py
@@ -28,14 +28,10 @@
extract_macro_references_and_variables,
)
from sqlmesh.utils.metaprogramming import Executable
from sqlmesh.utils.pydantic import (
PydanticModel,
field_validator,
model_validator,
model_validator_v1_args,
)
from sqlmesh.utils.pydantic import PydanticModel, field_validator, model_validator

if t.TYPE_CHECKING:
from sqlmesh.core._typing import Self
from sqlmesh.core.snapshot import DeployabilityIndex, Node, Snapshot


@@ -175,12 +171,10 @@ class StandaloneAudit(_Node, AuditMixin):
_depends_on_validator = depends_on_validator

@model_validator(mode="after")
@model_validator_v1_args
def _node_root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
if values.get("blocking"):
name = values.get("name")
raise AuditConfigError(f"Standalone audits cannot be blocking: '{name}'.")
return values
def _node_root_validator(self) -> Self:
if self.blocking:
raise AuditConfigError(f"Standalone audits cannot be blocking: '{self.name}'.")
return self

def render_audit_query(
self,
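The removed helpers themselves never appear in this diff. Purely as an illustration of what such a shim has to do, and not sqlmesh's actual implementation, an adapter for the `mode="after"` case might look like this:

import typing as t

from pydantic import BaseModel


def model_validator_v1_args(func: t.Callable[..., t.Dict[str, t.Any]]) -> t.Callable:
    # Hypothetical reconstruction: adapt a v1-style validator written as
    # (cls, values) -> values so it can run under pydantic v2's
    # @model_validator(mode="after"), which passes the model instance.
    def wrapper(self: BaseModel) -> BaseModel:
        values = {name: getattr(self, name) for name in type(self).model_fields}
        for name, value in func(type(self), values).items():
            setattr(self, name, value)
        return self

    return wrapper

Whatever its exact shape, removing the shim lets the rewritten validators work on `self` directly, as in the audit change above.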
187 changes: 91 additions & 96 deletions sqlmesh/core/config/connection.py
@@ -23,14 +23,12 @@
from sqlmesh.core.engine_adapter.shared import CatalogSupport
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.pydantic import (
field_validator,
model_validator,
model_validator_v1_args,
field_validator_v1_args,
)
from sqlmesh.utils.pydantic import ValidationInfo, field_validator, model_validator
from sqlmesh.utils.aws import validate_s3_uri

if t.TYPE_CHECKING:
from sqlmesh.core._typing import Self

logger = logging.getLogger(__name__)

RECOMMENDED_STATE_SYNC_ENGINES = {"postgres", "gcp_postgres", "mysql", "mssql"}
@@ -163,19 +161,20 @@ class BaseDuckDBConnectionConfig(ConnectionConfig):
_data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {}

@model_validator(mode="before")
@model_validator_v1_args
def _validate_database_catalogs(
cls, values: t.Dict[str, t.Optional[str]]
) -> t.Dict[str, t.Optional[str]]:
if db_path := values.get("database") and values.get("catalogs"):
def _validate_database_catalogs(cls, data: t.Any) -> t.Any:
if not isinstance(data, dict):
return data

if db_path := data.get("database") and data.get("catalogs"):
raise ConfigError(
"Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog"
)
if isinstance(db_path, str) and db_path.startswith("md:"):
raise ConfigError(
"Please use connection type 'motherduck' without the `md:` prefix if you want to use a MotherDuck database as the single `database`."
)
return values

return data

@property
def _engine_adapter(self) -> t.Type[EngineAdapter]:
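A `mode="before"` validator runs on the raw input before parsing, so it can receive values other than a dict (for example, an existing model instance passed to `model_validate`); the new `isinstance` guard passes such inputs through untouched. A minimal sketch of the pattern with simplified fields:

import typing as t

from pydantic import BaseModel, model_validator


class DuckDBConfig(BaseModel):
    database: t.Optional[str] = None
    catalogs: t.Optional[t.Dict[str, str]] = None

    @model_validator(mode="before")
    @classmethod
    def _validate_database_catalogs(cls, data: t.Any) -> t.Any:
        if not isinstance(data, dict):  # raw input may not be a dict
            return data
        if data.get("database") and data.get("catalogs"):
            raise ValueError("Cannot specify both `database` and `catalogs`.")
        return data


DuckDBConfig(database="local.db")                      # ok
DuckDBConfig.model_validate({"catalogs": {"c": "x"}})  # ok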
@@ -430,29 +429,29 @@ class SnowflakeConnectionConfig(ConnectionConfig):
_concurrent_tasks_validator = concurrent_tasks_validator

@model_validator(mode="before")
@model_validator_v1_args
def _validate_authenticator(
cls, values: t.Dict[str, t.Optional[str]]
) -> t.Dict[str, t.Optional[str]]:
from snowflake.connector.network import (
DEFAULT_AUTHENTICATOR,
OAUTH_AUTHENTICATOR,
)
def _validate_authenticator(cls, data: t.Any) -> t.Any:
if not isinstance(data, dict):
return data

auth = values.get("authenticator")
from snowflake.connector.network import DEFAULT_AUTHENTICATOR, OAUTH_AUTHENTICATOR

auth = data.get("authenticator")
auth = auth.upper() if auth else DEFAULT_AUTHENTICATOR
user = values.get("user")
password = values.get("password")
values["private_key"] = cls._get_private_key(values, auth) # type: ignore
user = data.get("user")
password = data.get("password")
data["private_key"] = cls._get_private_key(data, auth) # type: ignore

if (
auth == DEFAULT_AUTHENTICATOR
and not values.get("private_key")
and not data.get("private_key")
and (not user or not password)
):
raise ConfigError("User and password must be provided if using default authentication")
if auth == OAUTH_AUTHENTICATOR and not values.get("token"):

if auth == OAUTH_AUTHENTICATOR and not data.get("token"):
raise ConfigError("Token must be provided if using oauth authentication")
return values

return data

@classmethod
def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]:
@@ -621,26 +620,28 @@ class DatabricksConnectionConfig(ConnectionConfig):
_http_headers_validator = http_headers_validator

@model_validator(mode="before")
@model_validator_v1_args
def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
if not isinstance(data, dict):
return data

from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter

if DatabricksEngineAdapter.can_access_spark_session(
bool(values.get("disable_spark_session"))
bool(data.get("disable_spark_session"))
):
return values
return data

databricks_connect_use_serverless = values.get("databricks_connect_use_serverless")
databricks_connect_use_serverless = data.get("databricks_connect_use_serverless")
server_hostname, http_path, access_token, auth_type = (
values.get("server_hostname"),
values.get("http_path"),
values.get("access_token"),
values.get("auth_type"),
data.get("server_hostname"),
data.get("http_path"),
data.get("access_token"),
data.get("auth_type"),
)

if databricks_connect_use_serverless:
values["force_databricks_connect"] = True
values["disable_databricks_connect"] = False
data["force_databricks_connect"] = True
data["disable_databricks_connect"] = False

if (not server_hostname or not http_path or not access_token) and (
not databricks_connect_use_serverless and not auth_type
@@ -651,35 +652,35 @@ def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str
if (
databricks_connect_use_serverless
and not server_hostname
and not values.get("databricks_connect_server_hostname")
and not data.get("databricks_connect_server_hostname")
):
raise ValueError(
"`server_hostname` or `databricks_connect_server_hostname` is required when `databricks_connect_use_serverless` is set"
)
if DatabricksEngineAdapter.can_access_databricks_connect(
bool(values.get("disable_databricks_connect"))
bool(data.get("disable_databricks_connect"))
):
if not values.get("databricks_connect_access_token"):
values["databricks_connect_access_token"] = access_token
if not values.get("databricks_connect_server_hostname"):
values["databricks_connect_server_hostname"] = f"https://{server_hostname}"
if not data.get("databricks_connect_access_token"):
data["databricks_connect_access_token"] = access_token
if not data.get("databricks_connect_server_hostname"):
data["databricks_connect_server_hostname"] = f"https://{server_hostname}"
if not databricks_connect_use_serverless:
if not values.get("databricks_connect_cluster_id"):
if not data.get("databricks_connect_cluster_id"):
if t.TYPE_CHECKING:
assert http_path is not None
values["databricks_connect_cluster_id"] = http_path.split("/")[-1]
data["databricks_connect_cluster_id"] = http_path.split("/")[-1]

if auth_type:
from databricks.sql.auth.auth import AuthType

all_values = [m.value for m in AuthType]
if auth_type not in all_values:
all_data = [m.value for m in AuthType]
if auth_type not in all_data:
raise ValueError(
f"`auth_type` {auth_type} does not match a valid option: {all_values}"
f"`auth_type` {auth_type} does not match a valid option: {all_data}"
)

client_id = values.get("oauth_client_id")
client_secret = values.get("oauth_client_secret")
client_id = data.get("oauth_client_id")
client_secret = data.get("oauth_client_secret")

if client_secret and not client_id:
raise ValueError(
Expand All @@ -689,7 +690,7 @@ def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str
if not http_path:
raise ValueError("`http_path` is still required when using `auth_type`")

return values
return data

@property
def _connection_kwargs_keys(self) -> t.Set[str]:
@@ -866,26 +867,24 @@ class BigQueryConnectionConfig(ConnectionConfig):
type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery")

@field_validator("execution_project")
@field_validator_v1_args
def validate_execution_project(
cls,
v: t.Optional[str],
values: t.Dict[str, t.Any],
info: ValidationInfo,
) -> t.Optional[str]:
if v and not values.get("project"):
if v and not info.data.get("project"):
raise ConfigError(
"If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location."
)
return v

@field_validator("quota_project")
@field_validator_v1_args
def validate_quota_project(
cls,
v: t.Optional[str],
values: t.Dict[str, t.Any],
info: ValidationInfo,
) -> t.Optional[str]:
if v and not values.get("project"):
if v and not info.data.get("project"):
raise ConfigError(
"If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location."
)
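With pydantic v2's `field_validator`, previously validated fields are read from `ValidationInfo.data` instead of a v1-style `values` argument; only fields declared before the one being validated are available, so `project` must precede `execution_project` and `quota_project` on the model. A minimal sketch with simplified fields:

import typing as t

from pydantic import BaseModel, ValidationInfo, field_validator


class BigQueryConfig(BaseModel):
    project: t.Optional[str] = None
    execution_project: t.Optional[str] = None  # declared after `project`

    @field_validator("execution_project")
    @classmethod
    def validate_execution_project(
        cls, v: t.Optional[str], info: ValidationInfo
    ) -> t.Optional[str]:
        # info.data holds the fields validated so far, replacing `values`.
        if v and not info.data.get("project"):
            raise ValueError("`execution_project` requires `project` to be set.")
        return v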
@@ -998,12 +997,13 @@ class GCPPostgresConnectionConfig(ConnectionConfig):
pre_ping: bool = True

@model_validator(mode="before")
@model_validator_v1_args
def _validate_auth_method(
cls, values: t.Dict[str, t.Optional[str]]
) -> t.Dict[str, t.Optional[str]]:
password = values.get("password")
enable_iam_auth = values.get("enable_iam_auth")
def _validate_auth_method(cls, data: t.Any) -> t.Any:
if not isinstance(data, dict):
return data

password = data.get("password")
enable_iam_auth = data.get("enable_iam_auth")

if password and enable_iam_auth:
raise ConfigError(
"Invalid GCP Postgres connection configuration - both password and"
@@ -1016,7 +1016,8 @@ def _validate_auth_method(
" for a postgres user account or enable_iam_auth set to 'True'"
" for an IAM user account."
)
return values

return data

@property
def _connection_kwargs_keys(self) -> t.Set[str]:
@@ -1437,40 +1438,37 @@ class TrinoConnectionConfig(ConnectionConfig):
type_: t.Literal["trino"] = Field(alias="type", default="trino")

@model_validator(mode="after")
@model_validator_v1_args
def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
port = values.get("port")
if (
values["http_scheme"] == "http"
and not values["method"].is_no_auth
and not values["method"].is_basic
):
def _root_validator(self) -> Self:
port = self.port
if self.http_scheme == "http" and not self.method.is_no_auth and not self.method.is_basic:
raise ConfigError("HTTP scheme can only be used with no-auth or basic method")

if port is None:
values["port"] = 80 if values["http_scheme"] == "http" else 443
if (values["method"].is_ldap or values["method"].is_basic) and (
not values["password"] or not values["user"]
):
self.port = 80 if self.http_scheme == "http" else 443

if (self.method.is_ldap or self.method.is_basic) and (not self.password or not self.user):
raise ConfigError(
f"Username and Password must be provided if using {values['method'].value} authentication"
f"Username and Password must be provided if using {self.method.value} authentication"
)
if values["method"].is_kerberos and (
not values["principal"] or not values["keytab"] or not values["krb5_config"]

if self.method.is_kerberos and (
not self.principal or not self.keytab or not self.krb5_config
):
raise ConfigError(
"Kerberos requires the following fields: principal, keytab, and krb5_config"
)
if values["method"].is_jwt and not values["jwt_token"]:

if self.method.is_jwt and not self.jwt_token:
raise ConfigError("JWT requires `jwt_token` to be set")
if values["method"].is_certificate and (
not values["cert"]
or not values["client_certificate"]
or not values["client_private_key"]

if self.method.is_certificate and (
not self.cert or not self.client_certificate or not self.client_private_key
):
raise ConfigError(
"Certificate requires the following fields: cert, client_certificate, and client_private_key"
)
return values

return self

@property
def _connection_kwargs_keys(self) -> t.Set[str]:
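Because a `mode="after"` validator receives the constructed instance, defaults that depend on other fields can be filled in by assigning to `self` before returning it, as the Trino validator now does for `port`. A minimal sketch with simplified fields:

import typing as t

from pydantic import BaseModel, model_validator
from typing_extensions import Self


class TrinoConfig(BaseModel):
    http_scheme: str = "https"
    port: t.Optional[int] = None

    @model_validator(mode="after")
    def _default_port(self) -> Self:
        if self.port is None:  # derive the default from another field
            self.port = 80 if self.http_scheme == "http" else 443
        return self


assert TrinoConfig(http_scheme="http").port == 80
assert TrinoConfig().port == 443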
@@ -1677,26 +1675,23 @@ class AthenaConnectionConfig(ConnectionConfig):
type_: t.Literal["athena"] = Field(alias="type", default="athena")

@model_validator(mode="after")
@model_validator_v1_args
def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
work_group = values.get("work_group")
s3_staging_dir = values.get("s3_staging_dir")
s3_warehouse_location = values.get("s3_warehouse_location")
def _root_validator(self) -> Self:
work_group = self.work_group
s3_staging_dir = self.s3_staging_dir
s3_warehouse_location = self.s3_warehouse_location

if not work_group and not s3_staging_dir:
raise ConfigError("At least one of work_group or s3_staging_dir must be set")

if s3_staging_dir:
values["s3_staging_dir"] = validate_s3_uri(
s3_staging_dir, base=True, error_type=ConfigError
)
self.s3_staging_dir = validate_s3_uri(s3_staging_dir, base=True, error_type=ConfigError)

if s3_warehouse_location:
values["s3_warehouse_location"] = validate_s3_uri(
self.s3_warehouse_location = validate_s3_uri(
s3_warehouse_location, base=True, error_type=ConfigError
)

return values
return self

@property
def _connection_kwargs_keys(self) -> t.Set[str]:
