Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore!: remove pydantic v1 validator arg helpers #3615

Merged
merged 5 commits (source and target branch names did not render in this capture)
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sqlmesh/core/_typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
import typing as t

from sqlglot import exp
Expand All @@ -9,3 +10,8 @@
SchemaName = t.Union[str, exp.Table]
SessionProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]]
CustomMaterializationProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]]

if sys.version_info >= (3, 11):
from typing import Self as Self
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
else:
from typing_extensions import Self as Self
18 changes: 6 additions & 12 deletions sqlmesh/core/audit/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,10 @@
extract_macro_references_and_variables,
)
from sqlmesh.utils.metaprogramming import Executable
from sqlmesh.utils.pydantic import (
PydanticModel,
field_validator,
model_validator,
model_validator_v1_args,
)
from sqlmesh.utils.pydantic import PydanticModel, field_validator, model_validator

if t.TYPE_CHECKING:
from sqlmesh.core._typing import Self
from sqlmesh.core.snapshot import DeployabilityIndex, Node, Snapshot


Expand Down Expand Up @@ -175,12 +171,10 @@ class StandaloneAudit(_Node, AuditMixin):
_depends_on_validator = depends_on_validator

@model_validator(mode="after")
@model_validator_v1_args
def _node_root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
if values.get("blocking"):
name = values.get("name")
raise AuditConfigError(f"Standalone audits cannot be blocking: '{name}'.")
return values
def _node_root_validator(self) -> Self:
if self.blocking:
raise AuditConfigError(f"Standalone audits cannot be blocking: '{self.name}'.")
return self

def render_audit_query(
self,
Expand Down
187 changes: 91 additions & 96 deletions sqlmesh/core/config/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@
from sqlmesh.core.engine_adapter.shared import CatalogSupport
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.pydantic import (
field_validator,
model_validator,
model_validator_v1_args,
field_validator_v1_args,
)
from sqlmesh.utils.pydantic import ValidationInfo, field_validator, model_validator
from sqlmesh.utils.aws import validate_s3_uri

if t.TYPE_CHECKING:
from sqlmesh.core._typing import Self

logger = logging.getLogger(__name__)

RECOMMENDED_STATE_SYNC_ENGINES = {"postgres", "gcp_postgres", "mysql", "mssql"}
Expand Down Expand Up @@ -163,19 +161,20 @@ class BaseDuckDBConnectionConfig(ConnectionConfig):
_data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {}

@model_validator(mode="before")
@model_validator_v1_args
def _validate_database_catalogs(
cls, values: t.Dict[str, t.Optional[str]]
) -> t.Dict[str, t.Optional[str]]:
if db_path := values.get("database") and values.get("catalogs"):
def _validate_database_catalogs(cls, data: t.Any) -> t.Any:
if not isinstance(data, dict):
return data

if db_path := data.get("database") and data.get("catalogs"):
raise ConfigError(
"Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog"
)
if isinstance(db_path, str) and db_path.startswith("md:"):
raise ConfigError(
"Please use connection type 'motherduck' without the `md:` prefix if you want to use a MotherDuck database as the single `database`."
)
return values

return data

@property
def _engine_adapter(self) -> t.Type[EngineAdapter]:
Expand Down Expand Up @@ -430,29 +429,29 @@ class SnowflakeConnectionConfig(ConnectionConfig):
_concurrent_tasks_validator = concurrent_tasks_validator

@model_validator(mode="before")
@model_validator_v1_args
def _validate_authenticator(
cls, values: t.Dict[str, t.Optional[str]]
) -> t.Dict[str, t.Optional[str]]:
from snowflake.connector.network import (
DEFAULT_AUTHENTICATOR,
OAUTH_AUTHENTICATOR,
)
def _validate_authenticator(cls, data: t.Any) -> t.Any:
if not isinstance(data, dict):
return data

auth = values.get("authenticator")
from snowflake.connector.network import DEFAULT_AUTHENTICATOR, OAUTH_AUTHENTICATOR

auth = data.get("authenticator")
auth = auth.upper() if auth else DEFAULT_AUTHENTICATOR
user = values.get("user")
password = values.get("password")
values["private_key"] = cls._get_private_key(values, auth) # type: ignore
user = data.get("user")
password = data.get("password")
data["private_key"] = cls._get_private_key(data, auth) # type: ignore

if (
auth == DEFAULT_AUTHENTICATOR
and not values.get("private_key")
and not data.get("private_key")
and (not user or not password)
):
raise ConfigError("User and password must be provided if using default authentication")
if auth == OAUTH_AUTHENTICATOR and not values.get("token"):

if auth == OAUTH_AUTHENTICATOR and not data.get("token"):
raise ConfigError("Token must be provided if using oauth authentication")
return values

return data

@classmethod
def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]:
Expand Down Expand Up @@ -621,26 +620,28 @@ class DatabricksConnectionConfig(ConnectionConfig):
_http_headers_validator = http_headers_validator

@model_validator(mode="before")
@model_validator_v1_args
def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
if not isinstance(data, dict):
return data

from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter

if DatabricksEngineAdapter.can_access_spark_session(
bool(values.get("disable_spark_session"))
bool(data.get("disable_spark_session"))
):
return values
return data

databricks_connect_use_serverless = values.get("databricks_connect_use_serverless")
databricks_connect_use_serverless = data.get("databricks_connect_use_serverless")
server_hostname, http_path, access_token, auth_type = (
values.get("server_hostname"),
values.get("http_path"),
values.get("access_token"),
values.get("auth_type"),
data.get("server_hostname"),
data.get("http_path"),
data.get("access_token"),
data.get("auth_type"),
)

if databricks_connect_use_serverless:
values["force_databricks_connect"] = True
values["disable_databricks_connect"] = False
data["force_databricks_connect"] = True
data["disable_databricks_connect"] = False

if (not server_hostname or not http_path or not access_token) and (
not databricks_connect_use_serverless and not auth_type
Expand All @@ -651,35 +652,35 @@ def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str
if (
databricks_connect_use_serverless
and not server_hostname
and not values.get("databricks_connect_server_hostname")
and not data.get("databricks_connect_server_hostname")
):
raise ValueError(
"`server_hostname` or `databricks_connect_server_hostname` is required when `databricks_connect_use_serverless` is set"
)
if DatabricksEngineAdapter.can_access_databricks_connect(
bool(values.get("disable_databricks_connect"))
bool(data.get("disable_databricks_connect"))
):
if not values.get("databricks_connect_access_token"):
values["databricks_connect_access_token"] = access_token
if not values.get("databricks_connect_server_hostname"):
values["databricks_connect_server_hostname"] = f"https://{server_hostname}"
if not data.get("databricks_connect_access_token"):
data["databricks_connect_access_token"] = access_token
if not data.get("databricks_connect_server_hostname"):
data["databricks_connect_server_hostname"] = f"https://{server_hostname}"
if not databricks_connect_use_serverless:
if not values.get("databricks_connect_cluster_id"):
if not data.get("databricks_connect_cluster_id"):
if t.TYPE_CHECKING:
assert http_path is not None
values["databricks_connect_cluster_id"] = http_path.split("/")[-1]
data["databricks_connect_cluster_id"] = http_path.split("/")[-1]

if auth_type:
from databricks.sql.auth.auth import AuthType

all_values = [m.value for m in AuthType]
if auth_type not in all_values:
all_data = [m.value for m in AuthType]
if auth_type not in all_data:
raise ValueError(
f"`auth_type` {auth_type} does not match a valid option: {all_values}"
f"`auth_type` {auth_type} does not match a valid option: {all_data}"
)

client_id = values.get("oauth_client_id")
client_secret = values.get("oauth_client_secret")
client_id = data.get("oauth_client_id")
client_secret = data.get("oauth_client_secret")

if client_secret and not client_id:
raise ValueError(
Expand All @@ -689,7 +690,7 @@ def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str
if not http_path:
raise ValueError("`http_path` is still required when using `auth_type`")

return values
return data

@property
def _connection_kwargs_keys(self) -> t.Set[str]:
Expand Down Expand Up @@ -866,26 +867,24 @@ class BigQueryConnectionConfig(ConnectionConfig):
type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery")

@field_validator("execution_project")
@field_validator_v1_args
def validate_execution_project(
cls,
v: t.Optional[str],
values: t.Dict[str, t.Any],
info: ValidationInfo,
) -> t.Optional[str]:
if v and not values.get("project"):
if v and not info.data.get("project"):
raise ConfigError(
"If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location."
)
return v

@field_validator("quota_project")
@field_validator_v1_args
def validate_quota_project(
cls,
v: t.Optional[str],
values: t.Dict[str, t.Any],
info: ValidationInfo,
) -> t.Optional[str]:
if v and not values.get("project"):
if v and not info.data.get("project"):
raise ConfigError(
"If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location."
)
Expand Down Expand Up @@ -998,12 +997,13 @@ class GCPPostgresConnectionConfig(ConnectionConfig):
pre_ping: bool = True

@model_validator(mode="before")
@model_validator_v1_args
def _validate_auth_method(
cls, values: t.Dict[str, t.Optional[str]]
) -> t.Dict[str, t.Optional[str]]:
password = values.get("password")
enable_iam_auth = values.get("enable_iam_auth")
def _validate_auth_method(cls, data: t.Any) -> t.Any:
if not isinstance(data, dict):
return data

password = data.get("password")
enable_iam_auth = data.get("enable_iam_auth")

if password and enable_iam_auth:
raise ConfigError(
"Invalid GCP Postgres connection configuration - both password and"
Expand All @@ -1016,7 +1016,8 @@ def _validate_auth_method(
" for a postgres user account or enable_iam_auth set to 'True'"
" for an IAM user account."
)
return values

return data

@property
def _connection_kwargs_keys(self) -> t.Set[str]:
Expand Down Expand Up @@ -1437,40 +1438,37 @@ class TrinoConnectionConfig(ConnectionConfig):
type_: t.Literal["trino"] = Field(alias="type", default="trino")

@model_validator(mode="after")
@model_validator_v1_args
def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
port = values.get("port")
if (
values["http_scheme"] == "http"
and not values["method"].is_no_auth
and not values["method"].is_basic
):
def _root_validator(self) -> Self:
port = self.port
if self.http_scheme == "http" and not self.method.is_no_auth and not self.method.is_basic:
raise ConfigError("HTTP scheme can only be used with no-auth or basic method")

if port is None:
values["port"] = 80 if values["http_scheme"] == "http" else 443
if (values["method"].is_ldap or values["method"].is_basic) and (
not values["password"] or not values["user"]
):
self.port = 80 if self.http_scheme == "http" else 443

if (self.method.is_ldap or self.method.is_basic) and (not self.password or not self.user):
raise ConfigError(
f"Username and Password must be provided if using {values['method'].value} authentication"
f"Username and Password must be provided if using {self.method.value} authentication"
)
if values["method"].is_kerberos and (
not values["principal"] or not values["keytab"] or not values["krb5_config"]

if self.method.is_kerberos and (
not self.principal or not self.keytab or not self.krb5_config
):
raise ConfigError(
"Kerberos requires the following fields: principal, keytab, and krb5_config"
)
if values["method"].is_jwt and not values["jwt_token"]:

if self.method.is_jwt and not self.jwt_token:
raise ConfigError("JWT requires `jwt_token` to be set")
if values["method"].is_certificate and (
not values["cert"]
or not values["client_certificate"]
or not values["client_private_key"]

if self.method.is_certificate and (
not self.cert or not self.client_certificate or not self.client_private_key
):
raise ConfigError(
"Certificate requires the following fields: cert, client_certificate, and client_private_key"
)
return values

return self

@property
def _connection_kwargs_keys(self) -> t.Set[str]:
Expand Down Expand Up @@ -1677,26 +1675,23 @@ class AthenaConnectionConfig(ConnectionConfig):
type_: t.Literal["athena"] = Field(alias="type", default="athena")

@model_validator(mode="after")
@model_validator_v1_args
def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
work_group = values.get("work_group")
s3_staging_dir = values.get("s3_staging_dir")
s3_warehouse_location = values.get("s3_warehouse_location")
def _root_validator(self) -> Self:
work_group = self.work_group
s3_staging_dir = self.s3_staging_dir
s3_warehouse_location = self.s3_warehouse_location

if not work_group and not s3_staging_dir:
raise ConfigError("At least one of work_group or s3_staging_dir must be set")

if s3_staging_dir:
values["s3_staging_dir"] = validate_s3_uri(
s3_staging_dir, base=True, error_type=ConfigError
)
self.s3_staging_dir = validate_s3_uri(s3_staging_dir, base=True, error_type=ConfigError)

if s3_warehouse_location:
values["s3_warehouse_location"] = validate_s3_uri(
self.s3_warehouse_location = validate_s3_uri(
s3_warehouse_location, base=True, error_type=ConfigError
)

return values
return self

@property
def _connection_kwargs_keys(self) -> t.Set[str]:
Expand Down
Loading