From 1bbcd4f543347e7e49df0e3180141182084526a0 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Fri, 10 Jan 2025 20:48:59 +0200 Subject: [PATCH 1/5] Chore!: remove pydantic v1 validator arg helpers --- sqlmesh/core/_typing.py | 6 + sqlmesh/core/audit/definition.py | 18 +- sqlmesh/core/config/connection.py | 187 ++++++++++----------- sqlmesh/core/config/root.py | 58 ++++--- sqlmesh/core/config/scheduler.py | 13 +- sqlmesh/core/metric/definition.py | 17 +- sqlmesh/core/model/common.py | 22 +-- sqlmesh/core/model/definition.py | 7 +- sqlmesh/core/model/kind.py | 118 ++++++------- sqlmesh/core/model/meta.py | 106 ++++++------ sqlmesh/core/node.py | 28 ++- sqlmesh/core/state_sync/base.py | 11 +- sqlmesh/core/user.py | 16 +- sqlmesh/dbt/target.py | 122 +++++++------- sqlmesh/integrations/github/cicd/config.py | 18 +- sqlmesh/utils/pydantic.py | 30 +--- tests/core/test_model.py | 6 +- web/server/models.py | 15 +- 18 files changed, 360 insertions(+), 438 deletions(-) diff --git a/sqlmesh/core/_typing.py b/sqlmesh/core/_typing.py index 197d07bf2..e495df169 100644 --- a/sqlmesh/core/_typing.py +++ b/sqlmesh/core/_typing.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys import typing as t from sqlglot import exp @@ -9,3 +10,8 @@ SchemaName = t.Union[str, exp.Table] SessionProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]] CustomMaterializationProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]] + +if sys.version_info >= (3, 11): + from typing import Self as Self +else: + from typing_extensions import Self as Self diff --git a/sqlmesh/core/audit/definition.py b/sqlmesh/core/audit/definition.py index 9c4ce3b79..f56662c3a 100644 --- a/sqlmesh/core/audit/definition.py +++ b/sqlmesh/core/audit/definition.py @@ -28,14 +28,10 @@ extract_macro_references_and_variables, ) from sqlmesh.utils.metaprogramming import Executable -from sqlmesh.utils.pydantic import ( - PydanticModel, - field_validator, - model_validator, - model_validator_v1_args, -) +from sqlmesh.utils.pydantic import PydanticModel, field_validator, model_validator if t.TYPE_CHECKING: + from sqlmesh.core._typing import Self from sqlmesh.core.snapshot import DeployabilityIndex, Node, Snapshot @@ -175,12 +171,10 @@ class StandaloneAudit(_Node, AuditMixin): _depends_on_validator = depends_on_validator @model_validator(mode="after") - @model_validator_v1_args - def _node_root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - if values.get("blocking"): - name = values.get("name") - raise AuditConfigError(f"Standalone audits cannot be blocking: '{name}'.") - return values + def _node_root_validator(self) -> Self: + if self.blocking: + raise AuditConfigError(f"Standalone audits cannot be blocking: '{self.name}'.") + return self def render_audit_query( self, diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index d32c879ba..ac5d27c86 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -23,14 +23,12 @@ from sqlmesh.core.engine_adapter.shared import CatalogSupport from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.utils.errors import ConfigError -from sqlmesh.utils.pydantic import ( - field_validator, - model_validator, - model_validator_v1_args, - field_validator_v1_args, -) +from sqlmesh.utils.pydantic import ValidationInfo, field_validator, model_validator from sqlmesh.utils.aws import validate_s3_uri +if t.TYPE_CHECKING: + from sqlmesh.core._typing import Self + logger = 
logging.getLogger(__name__) RECOMMENDED_STATE_SYNC_ENGINES = {"postgres", "gcp_postgres", "mysql", "mssql"} @@ -163,11 +161,11 @@ class BaseDuckDBConnectionConfig(ConnectionConfig): _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {} @model_validator(mode="before") - @model_validator_v1_args - def _validate_database_catalogs( - cls, values: t.Dict[str, t.Optional[str]] - ) -> t.Dict[str, t.Optional[str]]: - if db_path := values.get("database") and values.get("catalogs"): + def _validate_database_catalogs(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + if db_path := data.get("database") and data.get("catalogs"): raise ConfigError( "Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog" ) @@ -175,7 +173,8 @@ def _validate_database_catalogs( raise ConfigError( "Please use connection type 'motherduck' without the `md:` prefix if you want to use a MotherDuck database as the single `database`." ) - return values + + return data @property def _engine_adapter(self) -> t.Type[EngineAdapter]: @@ -430,29 +429,29 @@ class SnowflakeConnectionConfig(ConnectionConfig): _concurrent_tasks_validator = concurrent_tasks_validator @model_validator(mode="before") - @model_validator_v1_args - def _validate_authenticator( - cls, values: t.Dict[str, t.Optional[str]] - ) -> t.Dict[str, t.Optional[str]]: - from snowflake.connector.network import ( - DEFAULT_AUTHENTICATOR, - OAUTH_AUTHENTICATOR, - ) + def _validate_authenticator(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data - auth = values.get("authenticator") + from snowflake.connector.network import DEFAULT_AUTHENTICATOR, OAUTH_AUTHENTICATOR + + auth = data.get("authenticator") auth = auth.upper() if auth else DEFAULT_AUTHENTICATOR - user = values.get("user") - password = values.get("password") - values["private_key"] = cls._get_private_key(values, auth) # type: ignore + user = data.get("user") + password = data.get("password") + data["private_key"] = cls._get_private_key(data, auth) # type: ignore + if ( auth == DEFAULT_AUTHENTICATOR - and not values.get("private_key") + and not data.get("private_key") and (not user or not password) ): raise ConfigError("User and password must be provided if using default authentication") - if auth == OAUTH_AUTHENTICATOR and not values.get("token"): + + if auth == OAUTH_AUTHENTICATOR and not data.get("token"): raise ConfigError("Token must be provided if using oauth authentication") - return values + + return data @classmethod def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]: @@ -621,26 +620,28 @@ class DatabricksConnectionConfig(ConnectionConfig): _http_headers_validator = http_headers_validator @model_validator(mode="before") - @model_validator_v1_args - def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + def _databricks_connect_validator(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter if DatabricksEngineAdapter.can_access_spark_session( - bool(values.get("disable_spark_session")) + bool(data.get("disable_spark_session")) ): - return values + return data - databricks_connect_use_serverless = values.get("databricks_connect_use_serverless") + databricks_connect_use_serverless = data.get("databricks_connect_use_serverless") server_hostname, http_path, access_token, auth_type = ( - 
values.get("server_hostname"), - values.get("http_path"), - values.get("access_token"), - values.get("auth_type"), + data.get("server_hostname"), + data.get("http_path"), + data.get("access_token"), + data.get("auth_type"), ) if databricks_connect_use_serverless: - values["force_databricks_connect"] = True - values["disable_databricks_connect"] = False + data["force_databricks_connect"] = True + data["disable_databricks_connect"] = False if (not server_hostname or not http_path or not access_token) and ( not databricks_connect_use_serverless and not auth_type @@ -651,35 +652,35 @@ def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str if ( databricks_connect_use_serverless and not server_hostname - and not values.get("databricks_connect_server_hostname") + and not data.get("databricks_connect_server_hostname") ): raise ValueError( "`server_hostname` or `databricks_connect_server_hostname` is required when `databricks_connect_use_serverless` is set" ) if DatabricksEngineAdapter.can_access_databricks_connect( - bool(values.get("disable_databricks_connect")) + bool(data.get("disable_databricks_connect")) ): - if not values.get("databricks_connect_access_token"): - values["databricks_connect_access_token"] = access_token - if not values.get("databricks_connect_server_hostname"): - values["databricks_connect_server_hostname"] = f"https://{server_hostname}" + if not data.get("databricks_connect_access_token"): + data["databricks_connect_access_token"] = access_token + if not data.get("databricks_connect_server_hostname"): + data["databricks_connect_server_hostname"] = f"https://{server_hostname}" if not databricks_connect_use_serverless: - if not values.get("databricks_connect_cluster_id"): + if not data.get("databricks_connect_cluster_id"): if t.TYPE_CHECKING: assert http_path is not None - values["databricks_connect_cluster_id"] = http_path.split("/")[-1] + data["databricks_connect_cluster_id"] = http_path.split("/")[-1] if auth_type: from databricks.sql.auth.auth import AuthType - all_values = [m.value for m in AuthType] - if auth_type not in all_values: + all_data = [m.value for m in AuthType] + if auth_type not in all_data: raise ValueError( - f"`auth_type` {auth_type} does not match a valid option: {all_values}" + f"`auth_type` {auth_type} does not match a valid option: {all_data}" ) - client_id = values.get("oauth_client_id") - client_secret = values.get("oauth_client_secret") + client_id = data.get("oauth_client_id") + client_secret = data.get("oauth_client_secret") if client_secret and not client_id: raise ValueError( @@ -689,7 +690,7 @@ def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str if not http_path: raise ValueError("`http_path` is still required when using `auth_type`") - return values + return data @property def _connection_kwargs_keys(self) -> t.Set[str]: @@ -866,26 +867,24 @@ class BigQueryConnectionConfig(ConnectionConfig): type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery") @field_validator("execution_project") - @field_validator_v1_args def validate_execution_project( cls, v: t.Optional[str], - values: t.Dict[str, t.Any], + info: ValidationInfo, ) -> t.Optional[str]: - if v and not values.get("project"): + if v and not info.data.get("project"): raise ConfigError( "If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location." 
) return v @field_validator("quota_project") - @field_validator_v1_args def validate_quota_project( cls, v: t.Optional[str], - values: t.Dict[str, t.Any], + info: ValidationInfo, ) -> t.Optional[str]: - if v and not values.get("project"): + if v and not info.data.get("project"): raise ConfigError( "If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location." ) @@ -998,12 +997,13 @@ class GCPPostgresConnectionConfig(ConnectionConfig): pre_ping: bool = True @model_validator(mode="before") - @model_validator_v1_args - def _validate_auth_method( - cls, values: t.Dict[str, t.Optional[str]] - ) -> t.Dict[str, t.Optional[str]]: - password = values.get("password") - enable_iam_auth = values.get("enable_iam_auth") + def _validate_auth_method(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + password = data.get("password") + enable_iam_auth = data.get("enable_iam_auth") + if password and enable_iam_auth: raise ConfigError( "Invalid GCP Postgres connection configuration - both password and" @@ -1016,7 +1016,8 @@ def _validate_auth_method( " for a postgres user account or enable_iam_auth set to 'True'" " for an IAM user account." ) - return values + + return data @property def _connection_kwargs_keys(self) -> t.Set[str]: @@ -1437,40 +1438,37 @@ class TrinoConnectionConfig(ConnectionConfig): type_: t.Literal["trino"] = Field(alias="type", default="trino") @model_validator(mode="after") - @model_validator_v1_args - def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - port = values.get("port") - if ( - values["http_scheme"] == "http" - and not values["method"].is_no_auth - and not values["method"].is_basic - ): + def _root_validator(self) -> Self: + port = self.port + if self.http_scheme == "http" and not self.method.is_no_auth and not self.method.is_basic: raise ConfigError("HTTP scheme can only be used with no-auth or basic method") + if port is None: - values["port"] = 80 if values["http_scheme"] == "http" else 443 - if (values["method"].is_ldap or values["method"].is_basic) and ( - not values["password"] or not values["user"] - ): + self.port = 80 if self.http_scheme == "http" else 443 + + if (self.method.is_ldap or self.method.is_basic) and (not self.password or not self.user): raise ConfigError( - f"Username and Password must be provided if using {values['method'].value} authentication" + f"Username and Password must be provided if using {self.method.value} authentication" ) - if values["method"].is_kerberos and ( - not values["principal"] or not values["keytab"] or not values["krb5_config"] + + if self.method.is_kerberos and ( + not self.principal or not self.keytab or not self.krb5_config ): raise ConfigError( "Kerberos requires the following fields: principal, keytab, and krb5_config" ) - if values["method"].is_jwt and not values["jwt_token"]: + + if self.method.is_jwt and not self.jwt_token: raise ConfigError("JWT requires `jwt_token` to be set") - if values["method"].is_certificate and ( - not values["cert"] - or not values["client_certificate"] - or not values["client_private_key"] + + if self.method.is_certificate and ( + not self.cert or not self.client_certificate or not self.client_private_key ): raise ConfigError( "Certificate requires the following fields: cert, client_certificate, and client_private_key" ) - return values + + return self @property def _connection_kwargs_keys(self) -> t.Set[str]: @@ -1677,26 +1675,23 @@ class AthenaConnectionConfig(ConnectionConfig): 
type_: t.Literal["athena"] = Field(alias="type", default="athena") @model_validator(mode="after") - @model_validator_v1_args - def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - work_group = values.get("work_group") - s3_staging_dir = values.get("s3_staging_dir") - s3_warehouse_location = values.get("s3_warehouse_location") + def _root_validator(self) -> Self: + work_group = self.work_group + s3_staging_dir = self.s3_staging_dir + s3_warehouse_location = self.s3_warehouse_location if not work_group and not s3_staging_dir: raise ConfigError("At least one of work_group or s3_staging_dir must be set") if s3_staging_dir: - values["s3_staging_dir"] = validate_s3_uri( - s3_staging_dir, base=True, error_type=ConfigError - ) + self.s3_staging_dir = validate_s3_uri(s3_staging_dir, base=True, error_type=ConfigError) if s3_warehouse_location: - values["s3_warehouse_location"] = validate_s3_uri( + self.s3_warehouse_location = validate_s3_uri( s3_warehouse_location, base=True, error_type=ConfigError ) - return values + return self @property def _connection_kwargs_keys(self) -> t.Set[str]: diff --git a/sqlmesh/core/config/root.py b/sqlmesh/core/config/root.py index a08b6af83..5f5431601 100644 --- a/sqlmesh/core/config/root.py +++ b/sqlmesh/core/config/root.py @@ -40,11 +40,10 @@ from sqlmesh.core.notification_target import NotificationTarget from sqlmesh.core.user import User from sqlmesh.utils.errors import ConfigError -from sqlmesh.utils.pydantic import ( - field_validator, - model_validator, - model_validator_v1_args, -) +from sqlmesh.utils.pydantic import field_validator, model_validator + +if t.TYPE_CHECKING: + from sqlmesh.core._typing import Self logger = logging.getLogger(__name__) @@ -144,7 +143,7 @@ class Config(BaseConfig): _scheduler_config_validator = scheduler_config_validator _variables_validator = variables_validator - @field_validator("gateways", mode="before", always=True) + @field_validator("gateways", mode="before") @classmethod def _gateways_ensure_dict(cls, value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: try: @@ -168,50 +167,57 @@ def _validate_regex_keys( return compiled_regexes @model_validator(mode="before") - @model_validator_v1_args - def _normalize_and_validate_fields(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - if "gateways" not in values and "gateway" in values: - values["gateways"] = values.pop("gateway") + def _normalize_and_validate_fields(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + if "gateways" not in data and "gateway" in data: + data["gateways"] = data.pop("gateway") for plan_deprecated in ("auto_categorize_changes", "include_unmodified"): - if plan_deprecated in values: + if plan_deprecated in data: raise ConfigError( f"The `{plan_deprecated}` config is deprecated. Please use the `plan.{plan_deprecated}` config instead." ) - if "physical_schema_override" in values: + if "physical_schema_override" in data: logger.warning( "`physical_schema_override` is deprecated. 
Please use `physical_schema_mapping` instead" ) - if "physical_schema_mapping" in values: + if "physical_schema_mapping" in data: raise ConfigError( "Only one of `physical_schema_override` and `physical_schema_mapping` can be specified" ) - physical_schema_override: t.Dict[str, str] = values.pop("physical_schema_override") + physical_schema_override: t.Dict[str, str] = data.pop("physical_schema_override") # translate physical_schema_override to physical_schema_mapping - values["physical_schema_mapping"] = { + data["physical_schema_mapping"] = { f"^{k}$": v for k, v in physical_schema_override.items() } - return values + return data @model_validator(mode="after") - @model_validator_v1_args - def _normalize_fields_after(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - dialect = values["model_defaults"].dialect + def _normalize_fields_after(self) -> Self: + dialect = self.model_defaults.dialect def _normalize_identifiers(key: str) -> None: - values[key] = { - k: normalize_identifiers(v, dialect=dialect).name - for k, v in values.get(key, {}).items() - } + setattr( + self, + key, + { + k: normalize_identifiers(v, dialect=dialect).name + for k, v in getattr(self, key, {}).items() + }, + ) - _normalize_identifiers("environment_catalog_mapping") - _normalize_identifiers("physical_schema_mapping") + if self.environment_catalog_mapping: + _normalize_identifiers("environment_catalog_mapping") + if self.physical_schema_mapping: + _normalize_identifiers("physical_schema_mapping") - return values + return self def get_default_test_connection( self, diff --git a/sqlmesh/core/config/scheduler.py b/sqlmesh/core/config/scheduler.py index 411e760f8..7ad4f75f2 100644 --- a/sqlmesh/core/config/scheduler.py +++ b/sqlmesh/core/config/scheduler.py @@ -23,7 +23,7 @@ from sqlmesh.schedulers.airflow.mwaa_client import MWAAClient from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.hashing import md5 -from sqlmesh.utils.pydantic import model_validator, model_validator_v1_args, field_validator +from sqlmesh.utils.pydantic import model_validator, field_validator if t.TYPE_CHECKING: from google.auth.transport.requests import AuthorizedSession @@ -381,15 +381,18 @@ def get_client(self, console: t.Optional[Console] = None) -> AirflowClient: ) @model_validator(mode="before") - @model_validator_v1_args - def check_supported_fields(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + def check_supported_fields(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + allowed_field_names = {field.alias or name for name, field in cls.all_field_infos().items()} allowed_field_names.add("session") - for field_name in values: + for field_name in data: if field_name not in allowed_field_names: raise ValueError(f"Unsupported Field: {field_name}") - return values + + return data class MWAASchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig): diff --git a/sqlmesh/core/metric/definition.py b/sqlmesh/core/metric/definition.py index 10da247b8..dd11cfd38 100644 --- a/sqlmesh/core/metric/definition.py +++ b/sqlmesh/core/metric/definition.py @@ -10,11 +10,7 @@ from sqlmesh.core.node import str_or_exp_to_str from sqlmesh.utils import UniqueKeyDict from sqlmesh.utils.errors import ConfigError -from sqlmesh.utils.pydantic import ( - PydanticModel, - field_validator, - field_validator_v1_args, -) +from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator MeasureAndDimTables = t.Tuple[str, t.Tuple[str, ...]] @@ -83,7 +79,7 @@ class MetricMeta(PydanticModel, 
frozen=True): @field_validator("name", mode="before") @classmethod def _name_validator(cls, v: t.Any) -> str: - return cls._string_validator(v).lower() + return (cls._string_validator(v) or "").lower() @field_validator("dialect", "owner", "description", mode="before") @classmethod @@ -91,14 +87,9 @@ def _string_validator(cls, v: t.Any) -> t.Optional[str]: return str_or_exp_to_str(v) @field_validator("expression", mode="before") - @field_validator_v1_args - def _validate_expression( - cls, - v: t.Any, - values: t.Dict[str, t.Any], - ) -> exp.Expression: + def _validate_expression(cls, v: t.Any, info: ValidationInfo) -> exp.Expression: if isinstance(v, str): - dialect = values.get("dialect") + dialect = info.data.get("dialect") return d.parse_one(v, dialect=dialect) if isinstance(v, exp.Expression): return v diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index 73b065574..032a90e99 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -14,7 +14,7 @@ from sqlmesh.utils import str_to_bool from sqlmesh.utils.errors import ConfigError, SQLMeshError, raise_config_error from sqlmesh.utils.metaprogramming import Executable, build_env, prepare_env, serialize_env -from sqlmesh.utils.pydantic import field_validator, field_validator_v1_args +from sqlmesh.utils.pydantic import ValidationInfo, field_validator if t.TYPE_CHECKING: from sqlmesh.utils.jinja import MacroReference @@ -194,11 +194,10 @@ def single_value_or_tuple(values: t.Sequence) -> exp.Identifier | exp.Tuple: ) -@field_validator_v1_args def parse_expression( cls: t.Type, v: t.Union[t.List[str], t.List[exp.Expression], str, exp.Expression, t.Callable, None], - values: t.Dict[str, t.Any], + info: t.Optional[ValidationInfo] = None, ) -> t.List[exp.Expression] | exp.Expression | t.Callable | None: """Helper method to deserialize SQLGlot expressions in Pydantic Models.""" if v is None: @@ -207,7 +206,7 @@ def parse_expression( if callable(v): return v - dialect = values.get("dialect") + dialect = info.data.get("dialect") if info else "" if isinstance(v, list): return [ @@ -231,12 +230,14 @@ def parse_bool(v: t.Any) -> bool: return str_to_bool(str(v or "")) -@field_validator_v1_args -def parse_properties(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> t.Optional[exp.Tuple]: +def parse_properties( + cls: t.Type, v: t.Any, info: t.Optional[ValidationInfo] = None +) -> t.Optional[exp.Tuple]: if v is None: return v - dialect = values.get("dialect") + dialect = info.data.get("dialect") if info else "" + if isinstance(v, str): v = d.parse_one(v, dialect=dialect) if isinstance(v, (exp.Array, exp.Paren, exp.Tuple)): @@ -272,10 +273,9 @@ def default_catalog(cls: t.Type, v: t.Any) -> t.Optional[str]: return str(v) -@field_validator_v1_args -def depends_on(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> t.Optional[t.Set[str]]: - dialect = values.get("dialect") - default_catalog = values.get("default_catalog") +def depends_on(cls: t.Type, v: t.Any, info: ValidationInfo) -> t.Optional[t.Set[str]]: + dialect = info.data.get("dialect") + default_catalog = info.data.get("default_catalog") if isinstance(v, exp.Paren): v = v.unnest() diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 69141e669..a33846c54 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -2,7 +2,6 @@ import json import logging -import sys import types import re import typing as t @@ -53,17 +52,13 @@ if t.TYPE_CHECKING: from sqlglot.dialects.dialect import 
DialectType - from sqlmesh.core._typing import TableName + from sqlmesh.core._typing import Self, TableName from sqlmesh.core.context import ExecutionContext from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.engine_adapter._typing import QueryOrDF from sqlmesh.core.snapshot import DeployabilityIndex, Node, Snapshot from sqlmesh.utils.jinja import MacroReference - if sys.version_info >= (3, 11): - from typing import Self - else: - from typing_extensions import Self logger = logging.getLogger(__name__) diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index 72d15a069..648ca1679 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -23,9 +23,9 @@ SQLGlotPositiveInt, SQLGlotString, SQLGlotCron, + ValidationInfo, column_validator, field_validator, - field_validator_v1_args, get_dialect, validate_string, ) @@ -237,49 +237,6 @@ class TimeColumn(PydanticModel): column: exp.Expression format: t.Optional[str] = None - @classmethod - def validator(cls) -> classmethod: - def _time_column_validator(v: t.Any, values: t.Any) -> TimeColumn: - dialect = get_dialect(values) - - if isinstance(v, exp.Tuple): - column_expr = v.expressions[0] - column = ( - exp.column(column_expr) - if isinstance(column_expr, exp.Identifier) - else column_expr - ) - format = v.expressions[1].name if len(v.expressions) > 1 else None - elif isinstance(v, exp.Expression): - column = exp.column(v) if isinstance(v, exp.Identifier) else v - format = None - elif isinstance(v, str): - column = d.parse_one(v, dialect=dialect) - column.meta.pop("sql") - format = None - elif isinstance(v, dict): - column_raw = v["column"] - column = ( - d.parse_one(column_raw, dialect=dialect) - if isinstance(column_raw, str) - else column_raw - ) - format = v.get("format") - elif isinstance(v, TimeColumn): - column = v.column - format = v.format - else: - raise ConfigError(f"Invalid time_column: '{v}'.") - - column = quote_identifiers( - normalize_identifiers(column, dialect=dialect), dialect=dialect - ) - column.meta["dialect"] = dialect - - return TimeColumn(column=column, format=format) - - return field_validator("time_column", mode="before")(_time_column_validator) - @field_validator("column", mode="before") @classmethod def _column_validator(cls, v: t.Union[str, exp.Expression]) -> exp.Expression: @@ -321,9 +278,7 @@ def _kind_dialect_validator(cls: t.Type, v: t.Optional[str]) -> str: return v -kind_dialect_validator = field_validator("dialect", mode="before", always=True)( - _kind_dialect_validator -) +kind_dialect_validator = field_validator("dialect", mode="before")(_kind_dialect_validator) class _Incremental(_ModelKind): @@ -409,7 +364,42 @@ class IncrementalByTimeRangeKind(_IncrementalBy): time_column: TimeColumn auto_restatement_intervals: t.Optional[SQLGlotPositiveInt] = None - _time_column_validator = TimeColumn.validator() + @field_validator("time_column", mode="before") + @classmethod + def _time_column_validator(cls, v: t.Any, values: t.Any) -> TimeColumn: + dialect = get_dialect(values) + + if isinstance(v, exp.Tuple): + column_expr = v.expressions[0] + column = ( + exp.column(column_expr) if isinstance(column_expr, exp.Identifier) else column_expr + ) + format = v.expressions[1].name if len(v.expressions) > 1 else None + elif isinstance(v, exp.Expression): + column = exp.column(v) if isinstance(v, exp.Identifier) else v + format = None + elif isinstance(v, str): + column = d.parse_one(v, dialect=dialect) + column.meta.pop("sql") + format = None + elif isinstance(v, dict): + 
column_raw = v["column"] + column = ( + d.parse_one(column_raw, dialect=dialect) + if isinstance(column_raw, str) + else column_raw + ) + format = v.get("format") + elif isinstance(v, TimeColumn): + column = v.column + format = v.format + else: + raise ConfigError(f"Invalid time_column: '{v}'.") + + column = quote_identifiers(normalize_identifiers(column, dialect=dialect), dialect=dialect) + column.meta["dialect"] = dialect + + return TimeColumn(column=column, format=format) def to_expression( self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any @@ -450,11 +440,10 @@ class IncrementalByUniqueKeyKind(_IncrementalBy): batch_concurrency: t.Literal[1] = 1 @field_validator("when_matched", mode="before") - @field_validator_v1_args def _when_matched_validator( cls, v: t.Optional[t.Union[str, exp.Whens]], - values: t.Dict[str, t.Any], + info: ValidationInfo, ) -> t.Optional[exp.Whens]: if v is None: return v @@ -464,22 +453,21 @@ def _when_matched_validator( if v.startswith("("): v = v[1:-1] - return t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=get_dialect(values))) + return t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=get_dialect(info.data))) return t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases)) @field_validator("merge_filter", mode="before") - @field_validator_v1_args def _merge_filter_validator( cls, v: t.Optional[exp.Expression], - values: t.Dict[str, t.Any], + info: ValidationInfo, ) -> t.Optional[exp.Expression]: if v is None: return v if isinstance(v, str): v = v.strip() - return d.parse_one(v, dialect=get_dialect(values)) + return d.parse_one(v, dialect=get_dialect(info.data)) return v.transform(d.replace_merge_table_aliases) @@ -616,7 +604,7 @@ def _parse_csv_settings(cls, v: t.Any) -> t.Optional[CsvSettings]: if v is None or isinstance(v, CsvSettings): return v if isinstance(v, exp.Expression): - tuple_exp = parse_properties(cls, v, {}) + tuple_exp = parse_properties(cls, v) if not tuple_exp: return None return CsvSettings(**{e.left.name: e.right for e in tuple_exp.expressions}) @@ -669,13 +657,11 @@ class _SCDType2Kind(_Incremental): _dialect_validator = kind_dialect_validator - # Remove once Pydantic 1 is deprecated - _always_validate_column = field_validator( - "valid_from_name", "valid_to_name", mode="before", always=True - )(column_validator) + _always_validate_column = field_validator("valid_from_name", "valid_to_name", mode="before")( + column_validator + ) - # always=True can be removed once Pydantic 1 is deprecated - @field_validator("time_data_type", mode="before", always=True) + @field_validator("time_data_type", mode="before") @classmethod def _time_data_type_validator( cls, v: t.Union[str, exp.Expression], values: t.Any @@ -742,8 +728,7 @@ class SCDType2ByTimeKind(_SCDType2Kind): updated_at_name: SQLGlotColumn = Field(exp.column("updated_at"), validate_default=True) updated_at_as_valid_from: SQLGlotBool = False - # Remove once Pydantic 1 is deprecated - _always_validate_updated_at = field_validator("updated_at_name", mode="before", always=True)( + _always_validate_updated_at = field_validator("updated_at_name", mode="before")( column_validator ) @@ -986,9 +971,10 @@ def create_model_kind(v: t.Any, dialect: str, defaults: t.Dict[str, t.Any]) -> M return model_kind_type_from_name(name)(name=name) # type: ignore -@field_validator_v1_args -def _model_kind_validator(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> ModelKind: - dialect = get_dialect(values) +def _model_kind_validator( + cls: t.Type, v: t.Any, 
info: t.Optional[ValidationInfo] = None +) -> ModelKind: + dialect = get_dialect(info.data) if info else "" return create_model_kind(v, dialect, {}) diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py index bf5e85211..a15ef229b 100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -3,6 +3,7 @@ import logging import typing as t from functools import cached_property +from typing_extensions import Self from pydantic import Field from sqlglot import Dialect, exp @@ -34,11 +35,10 @@ from sqlmesh.utils.date import TimeLike from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import ( + ValidationInfo, field_validator, - field_validator_v1_args, list_of_fields_validator, model_validator, - model_validator_v1_args, ) if t.TYPE_CHECKING: @@ -124,15 +124,14 @@ def _func_call_validator(cls, v: t.Any, field: t.Any) -> t.Any: return v or [] @field_validator("tags", mode="before") - @field_validator_v1_args - def _value_or_tuple_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> t.Any: - return ensure_list(cls._validate_value_or_tuple(v, values)) + def _value_or_tuple_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any: + return ensure_list(cls._validate_value_or_tuple(v, info.data)) @classmethod def _validate_value_or_tuple( - cls, v: t.Dict[str, t.Any], values: t.Dict[str, t.Any], normalize: bool = False + cls, v: t.Dict[str, t.Any], data: t.Dict[str, t.Any], normalize: bool = False ) -> t.Any: - dialect = values.get("dialect") + dialect = data.get("dialect") def _normalize(value: t.Any) -> t.Any: return normalize_identifiers(value, dialect=dialect) if normalize else value @@ -148,15 +147,14 @@ def _normalize(value: t.Any) -> t.Any: value = _normalize(v) return value.name if isinstance(value, exp.Expression) else value if isinstance(v, (list, tuple)): - return [cls._validate_value_or_tuple(elm, values, normalize=normalize) for elm in v] + return [cls._validate_value_or_tuple(elm, data, normalize=normalize) for elm in v] return v @field_validator("table_format", "storage_format", mode="before") - @field_validator_v1_args - def _format_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> t.Optional[str]: + def _format_validator(cls, v: t.Any, info: ValidationInfo) -> t.Optional[str]: if isinstance(v, exp.Expression) and not (isinstance(v, (exp.Literal, exp.Identifier))): - return v.sql(values.get("dialect")) + return v.sql(info.data.get("dialect")) return str_or_exp_to_str(v) @field_validator("dialect", mode="before") @@ -180,11 +178,10 @@ def _gateway_validator(cls, v: t.Any) -> t.Optional[str]: return gateway and gateway.lower() @field_validator("partitioned_by_", "clustered_by", mode="before") - @field_validator_v1_args def _partition_and_cluster_validator( - cls, v: t.Any, values: t.Dict[str, t.Any] + cls, v: t.Any, info: ValidationInfo ) -> t.List[exp.Expression]: - expressions = list_of_fields_validator(v, values) + expressions = list_of_fields_validator(v, info.data) for expression in expressions: num_cols = len(list(expression.find_all(exp.Column))) @@ -203,12 +200,11 @@ def _partition_and_cluster_validator( @field_validator( "columns_to_types_", "derived_columns_to_types", mode="before", check_fields=False ) - @field_validator_v1_args def _columns_validator( - cls, v: t.Any, values: t.Dict[str, t.Any] + cls, v: t.Any, info: ValidationInfo ) -> t.Optional[t.Dict[str, exp.DataType]]: columns_to_types = {} - dialect = values.get("dialect") + dialect = info.data.get("dialect") if isinstance(v, exp.Schema): for column in 
v.expressions: @@ -230,11 +226,10 @@ def _columns_validator( return v @field_validator("column_descriptions_", mode="before") - @field_validator_v1_args def _column_descriptions_validator( - cls, vs: t.Any, values: t.Dict[str, t.Any] + cls, vs: t.Any, info: ValidationInfo ) -> t.Optional[t.Dict[str, str]]: - dialect = values.get("dialect") + dialect = info.data.get("dialect") if vs is None: return None @@ -254,20 +249,19 @@ def _column_descriptions_validator( for k, v in raw_col_descriptions.items() } - columns_to_types = values.get("columns_to_types_") + columns_to_types = info.data.get("columns_to_types_") if columns_to_types: for column_name in col_descriptions: if column_name not in columns_to_types: raise ConfigError( - f"In model '{values['name']}', a description is provided for column '{column_name}' but it is not a column in the model." + f"In model '{info.data['name']}', a description is provided for column '{column_name}' but it is not a column in the model." ) return col_descriptions @field_validator("grains", "references", mode="before") - @field_validator_v1_args - def _refs_validator(cls, vs: t.Any, values: t.Dict[str, t.Any]) -> t.List[exp.Expression]: - dialect = values.get("dialect") + def _refs_validator(cls, vs: t.Any, info: ValidationInfo) -> t.List[exp.Expression]: + dialect = info.data.get("dialect") if isinstance(vs, exp.Paren): vs = vs.unnest() @@ -290,66 +284,64 @@ def _refs_validator(cls, vs: t.Any, values: t.Dict[str, t.Any]) -> t.List[exp.Ex return refs @model_validator(mode="before") - @model_validator_v1_args - def _pre_root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - grain = values.pop("grain", None) + def _pre_root_validator(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + grain = data.pop("grain", None) if grain: - grains = values.get("grains") + grains = data.get("grains") if grains: raise ConfigError( f"Cannot use argument 'grain' ({grain}) with 'grains' ({grains}), use only grains" ) - values["grains"] = ensure_list(grain) + data["grains"] = ensure_list(grain) - table_properties = values.pop("table_properties", None) + table_properties = data.pop("table_properties", None) if table_properties: if not isinstance(table_properties, str): # Do not warn when deserializing from the state. - model_name = values["name"] + model_name = data["name"] logger.warning( f"Model '{model_name}' is using the `table_properties` attribute which is deprecated. Please use `physical_properties` instead." ) - physical_properties = values.get("physical_properties") + physical_properties = data.get("physical_properties") if physical_properties: raise ConfigError( f"Cannot use argument 'table_properties' ({table_properties}) with 'physical_properties' ({physical_properties}), use only physical_properties." 
) - values["physical_properties"] = table_properties - return values + + data["physical_properties"] = table_properties + + return data @model_validator(mode="after") - @model_validator_v1_args - def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - values = cls._kind_validator(values) + def _root_validator(self) -> Self: + kind: t.Any = self.kind + + for field in ("partitioned_by_", "clustered_by"): + if ( + getattr(self, field, None) + and not kind.is_materialized + and not (kind.is_view and kind.materialized) + ): + name = field[:-1] if field.endswith("_") else field + raise ValueError(f"{name} field cannot be set for {kind} models") + if kind.is_incremental_by_partition and not getattr(self, "partitioned_by_", None): + raise ValueError(f"partitioned_by field is required for {kind.name} models") # needs to be in a mode=after model validator so that the field validators have run to convert from Expression -> str - if (storage_format := values.get("storage_format")) and storage_format.lower() in { + if (storage_format := self.storage_format) and storage_format.lower() in { "iceberg", "hive", "hudi", "delta", }: logger.warning( - f"Model {values['name']} has `storage_format` set to a table format '{storage_format}' which is deprecated. Please use the `table_format` property instead" + f"Model {self.name} has `storage_format` set to a table format '{storage_format}' which is deprecated. Please use the `table_format` property instead" ) - return values - - @classmethod - def _kind_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - kind = values.get("kind") - if kind: - for field in ("partitioned_by_", "clustered_by"): - if ( - values.get(field) - and not kind.is_materialized - and not (kind.is_view and kind.materialized) - ): - name = field[:-1] if field.endswith("_") else field - raise ValueError(f"{name} field cannot be set for {kind} models") - if kind.is_incremental_by_partition and not values.get("partitioned_by_"): - raise ValueError(f"partitioned_by field is required for {kind.name} models") - return values + return self @property def time_column(self) -> t.Optional[TimeColumn]: diff --git a/sqlmesh/core/node.py b/sqlmesh/core/node.py index 6f2ff85b0..82cfa170f 100644 --- a/sqlmesh/core/node.py +++ b/sqlmesh/core/node.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys import typing as t from datetime import datetime from enum import Enum @@ -17,18 +16,13 @@ SQLGlotCron, field_validator, model_validator, - model_validator_v1_args, PRIVATE_FIELDS, ) if t.TYPE_CHECKING: + from sqlmesh.core._typing import Self from sqlmesh.core.snapshot import Node - if sys.version_info >= (3, 11): - from typing import Self - else: - from typing_extensions import Self - class IntervalUnit(str, Enum): """IntervalUnit is the inferred granularity of an incremental node. 
@@ -257,22 +251,24 @@ def _interval_unit_validator(cls, v: t.Any) -> t.Optional[t.Union[IntervalUnit, return v @model_validator(mode="after") - @model_validator_v1_args - def _node_root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - interval_unit = values.get("interval_unit_") - if interval_unit and not values.get("allow_partials"): - cron = values["cron"] + def _node_root_validator(self) -> Self: + interval_unit = self.interval_unit_ + if interval_unit and not getattr(self, "allow_partials", None): + cron = self.cron max_interval_unit = IntervalUnit.from_cron(cron) if interval_unit.seconds > max_interval_unit.seconds: raise ConfigError( - f"Cron '{cron}' cannot be more frequent than interval unit '{interval_unit.value}'. If this is intentional, set allow_partials to True." + f"Cron '{cron}' cannot be more frequent than interval unit '{interval_unit.value}'. " + "If this is intentional, set allow_partials to True." ) - start = values.get("start") - end = values.get("end") + + start = self.start + end = self.end + if end is not None and start is None: raise ConfigError("Must define a start date if an end date is defined.") validate_date_range(start, end) - return values + return self @property def batch_size(self) -> t.Optional[int]: diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 695624311..9aeef86ec 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -23,11 +23,7 @@ from sqlmesh.utils import major_minor from sqlmesh.utils.date import TimeLike from sqlmesh.utils.errors import SQLMeshError -from sqlmesh.utils.pydantic import ( - PydanticModel, - field_validator, - field_validator_v1_args, -) +from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator logger = logging.getLogger(__name__) @@ -71,11 +67,10 @@ class PromotionResult(PydanticModel): removed_environment_naming_info: t.Optional[EnvironmentNamingInfo] @field_validator("removed_environment_naming_info") - @field_validator_v1_args def _validate_removed_environment_naming_info( - cls, v: t.Optional[EnvironmentNamingInfo], values: t.Any + cls, v: t.Optional[EnvironmentNamingInfo], info: ValidationInfo ) -> t.Optional[EnvironmentNamingInfo]: - if v and not values["removed"]: + if v and not info.data.get("removed"): raise ValueError("removed_environment_naming_info must be None if removed is empty") return v diff --git a/sqlmesh/core/user.py b/sqlmesh/core/user.py index ad6a3221c..fabc06516 100644 --- a/sqlmesh/core/user.py +++ b/sqlmesh/core/user.py @@ -1,15 +1,8 @@ import typing as t from enum import Enum -from sqlmesh.core.notification_target import ( - BasicSMTPNotificationTarget, - NotificationTarget, -) -from sqlmesh.utils.pydantic import ( - PydanticModel, - field_validator, - field_validator_v1_args, -) +from sqlmesh.core.notification_target import BasicSMTPNotificationTarget, NotificationTarget +from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator class UserRole(str, Enum): @@ -44,13 +37,12 @@ def is_required_approver(self) -> bool: return UserRole.REQUIRED_APPROVER in self.roles @field_validator("notification_targets") - @field_validator_v1_args def validate_notification_targets( cls, v: t.List[NotificationTarget], - values: t.Dict[str, t.Any], + info: ValidationInfo, ) -> t.List[NotificationTarget]: - email = values["email"] + email = info.data["email"] for target in v: if isinstance(target, BasicSMTPNotificationTarget) and target.recipients != {email}: raise ValueError("Recipient 
emails do not match user email") diff --git a/sqlmesh/dbt/target.py b/sqlmesh/dbt/target.py index f5ed4576f..86a57dfdd 100644 --- a/sqlmesh/dbt/target.py +++ b/sqlmesh/dbt/target.py @@ -34,11 +34,7 @@ from sqlmesh.dbt.util import DBT_VERSION from sqlmesh.utils import AttributeDict, classproperty from sqlmesh.utils.errors import ConfigError -from sqlmesh.utils.pydantic import ( - field_validator, - model_validator, - model_validator_v1_args, -) +from sqlmesh.utils.pydantic import field_validator, model_validator logger = logging.getLogger(__name__) @@ -167,20 +163,22 @@ class DuckDbConfig(TargetConfig): settings: t.Optional[t.Dict[str, t.Any]] = None @model_validator(mode="before") - @model_validator_v1_args - def validate_authentication( - cls, values: t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]] - ) -> t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]]: - if "database" not in values and DBT_VERSION >= (1, 5): - path = values.get("path") - values["database"] = ( + def validate_authentication(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + if "database" not in data and DBT_VERSION >= (1, 5): + path = data.get("path") + data["database"] = ( "memory" if path is None or path == DUCKDB_IN_MEMORY else Path(t.cast(str, path)).stem ) - if "threads" in values and t.cast(int, values["threads"]) > 1: + + if "threads" in data and t.cast(int, data["threads"]) > 1: logger.warning("DuckDB does not support concurrency - setting threads to 1.") - return values + + return data def default_incremental_strategy(self, kind: IncrementalKind) -> str: return "delete+insert" @@ -257,17 +255,17 @@ class SnowflakeConfig(TargetConfig): retry_all: bool = False @model_validator(mode="before") - @model_validator_v1_args - def validate_authentication( - cls, values: t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]] - ) -> t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]]: - if ( - values.get("password") - or values.get("authenticator") - or values.get("private_key") - or values.get("private_key_path") - ): - return values + @classmethod + def validate_authentication(cls, data: t.Any) -> t.Any: + if isinstance(data, dict): + if ( + data.get("password") + or data.get("authenticator") + or data.get("private_key") + or data.get("private_key_path") + ): + return data + raise ConfigError("No supported Snowflake authentication method found in target profile.") def default_incremental_strategy(self, kind: IncrementalKind) -> str: @@ -339,14 +337,16 @@ class PostgresConfig(TargetConfig): sslmode: t.Optional[str] = None @model_validator(mode="before") - @model_validator_v1_args - def validate_database( - cls, values: t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]] - ) -> t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]]: - values["database"] = values.get("database") or values.get("dbname") - if not values["database"]: + @classmethod + def validate_database(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + data["database"] = data.get("database") or data.get("dbname") + if not data["database"]: raise ConfigError("Either database or dbname must be set") - return values + + return data @field_validator("port") @classmethod @@ -401,14 +401,13 @@ class RedshiftConfig(TargetConfig): sslmode: t.Optional[str] = None @model_validator(mode="before") - @model_validator_v1_args - def validate_database( - cls, values: 
t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]] - ) -> t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]]: - values["database"] = values.get("database") or values.get("dbname") - if not values["database"]: - raise ConfigError("Either database or dbname must be set") - return values + @classmethod + def validate_database(cls, data: t.Any) -> t.Any: + if isinstance(data, dict): + data["database"] = data.get("database") or data.get("dbname") + if not data["database"]: + raise ConfigError("Either database or dbname must be set") + return data def default_incremental_strategy(self, kind: IncrementalKind) -> str: return "append" @@ -546,17 +545,19 @@ class BigQueryConfig(TargetConfig): maximum_bytes_billed: t.Optional[int] = None @model_validator(mode="before") - @model_validator_v1_args - def validate_fields( - cls, values: t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]] - ) -> t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]]: - values["schema"] = values.get("schema") or values.get("dataset") - if not values["schema"]: + @classmethod + def validate_fields(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + data["schema"] = data.get("schema") or data.get("dataset") + if not data["schema"]: raise ConfigError("Either schema or dataset must be set") - values["database"] = values.get("database") or values.get("project") - if not values["database"]: + data["database"] = data.get("database") or data.get("project") + if not data["database"]: raise ConfigError("Either database or project must be set") - return values + + return data def default_incremental_strategy(self, kind: IncrementalKind) -> str: return "merge" @@ -661,23 +662,24 @@ class MSSQLConfig(TargetConfig): client_secret: t.Optional[str] = None # Azure Active Directory auth @model_validator(mode="before") - @model_validator_v1_args - def validate_alias_fields( - cls, values: t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]] - ) -> t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]]: - values["host"] = values.get("host") or values.get("server") - if not values["host"]: + @classmethod + def validate_alias_fields(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + data["host"] = data.get("host") or data.get("server") + if not data["host"]: raise ConfigError("Either host or server must be set") - values["user"] = values.get("user") or values.get("username") or values.get("UID") - if not values["user"]: + data["user"] = data.get("user") or data.get("username") or data.get("UID") + if not data["user"]: raise ConfigError("One of user, username, or UID must be set") - values["password"] = values.get("password") or values.get("PWD") - if not values["password"]: + data["password"] = data.get("password") or data.get("PWD") + if not data["password"]: raise ConfigError("Either password or PWD must be set") - return values + return data @field_validator("authentication") @classmethod diff --git a/sqlmesh/integrations/github/cicd/config.py b/sqlmesh/integrations/github/cicd/config.py index b543331b2..04ce2337b 100644 --- a/sqlmesh/integrations/github/cicd/config.py +++ b/sqlmesh/integrations/github/cicd/config.py @@ -6,7 +6,7 @@ from sqlmesh.core.config import CategorizerConfig from sqlmesh.core.config.base import BaseConfig from sqlmesh.utils.date import TimeLike -from sqlmesh.utils.pydantic import model_validator, model_validator_v1_args 
+from sqlmesh.utils.pydantic import model_validator class MergeMethod(str, Enum): @@ -30,13 +30,15 @@ class GithubCICDBotConfig(BaseConfig): pr_environment_name: t.Optional[str] = None @model_validator(mode="before") - @model_validator_v1_args - def _validate(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - if values.get("enable_deploy_command") and not values.get("merge_method"): - raise ValueError("merge_method must be set if enable_deploy_command is True") - if values.get("command_namespace") and not values.get("enable_deploy_command"): - raise ValueError("enable_deploy_command must be set if command_namespace is set") - return values + @classmethod + def _validate(cls, data: t.Any) -> t.Any: + if isinstance(data, dict): + if data.get("enable_deploy_command") and not data.get("merge_method"): + raise ValueError("merge_method must be set if enable_deploy_command is True") + if data.get("command_namespace") and not data.get("enable_deploy_command"): + raise ValueError("enable_deploy_command must be set if command_namespace is set") + + return data FIELDS_FOR_ANALYTICS: t.ClassVar[t.Set[str]] = { "invalidate_environment_after_deploy", diff --git a/sqlmesh/utils/pydantic.py b/sqlmesh/utils/pydantic.py index d28b5aa13..f710a934d 100644 --- a/sqlmesh/utils/pydantic.py +++ b/sqlmesh/utils/pydantic.py @@ -2,9 +2,9 @@ import json import typing as t -from functools import wraps import pydantic +from pydantic import ValidationInfo as ValidationInfo from pydantic.fields import FieldInfo from sqlglot import exp, parse_one from sqlglot.helper import ensure_list @@ -27,14 +27,10 @@ def field_validator(*args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Any], t.Any]: - # Pydantic v2 doesn't support "always" argument. The validator behaves as if "always" is True. - kwargs.pop("always", None) return pydantic.field_validator(*args, **kwargs) def model_validator(*args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Any], t.Any]: - # Pydantic v2 doesn't support "always" argument. The validator behaves as if "always" is True. 
- kwargs.pop("always", None) return pydantic.model_validator(*args, **kwargs) @@ -181,30 +177,6 @@ def __repr__(self) -> str: return str(self) -def model_validator_v1_args(func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: - @wraps(func) - def wrapper(cls: t.Type, values: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: - is_values_dict = isinstance(values, dict) - values_dict = values if is_values_dict else values.__dict__ - result = func(cls, values_dict, *args, **kwargs) - if is_values_dict: - return result - else: - values.__dict__.update(result) - return values - - return wrapper - - -def field_validator_v1_args(func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: - @wraps(func) - def wrapper(cls: t.Type, v: t.Any, values: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: - values_dict = values if isinstance(values, dict) else values.data - return func(cls, v, values_dict, *args, **kwargs) - - return wrapper - - def validate_list_of_strings(v: t.Any) -> t.List[str]: if isinstance(v, exp.Identifier): return [v.name] diff --git a/tests/core/test_model.py b/tests/core/test_model.py index e08cf0d8c..04df69d89 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -2595,7 +2595,7 @@ def test_parse_expression_list_with_jinja(): "JINJA_STATEMENT_BEGIN;\n{{ log('log message') }}\nJINJA_END;", "GRANT SELECT ON TABLE foo TO DEV", ] - assert input == [val.sql() for val in parse_expression(SqlModel, input, {})] + assert input == [val.sql() for val in parse_expression(SqlModel, input)] def test_no_depends_on_runtime_jinja_query(): @@ -3687,7 +3687,7 @@ def test_scd_type_2_by_time_overrides(): assert not scd_type_2_model.kind.disable_restatement model_kind_dict = scd_type_2_model.kind.dict() - assert scd_type_2_model.kind == _model_kind_validator(None, model_kind_dict, {}) + assert scd_type_2_model.kind == _model_kind_validator(None, model_kind_dict) def test_scd_type_2_by_column_defaults(): @@ -3776,7 +3776,7 @@ def test_scd_type_2_by_column_overrides(): assert not scd_type_2_model.kind.disable_restatement model_kind_dict = scd_type_2_model.kind.dict() - assert scd_type_2_model.kind == _model_kind_validator(None, model_kind_dict, {}) + assert scd_type_2_model.kind == _model_kind_validator(None, model_kind_dict) def test_scd_type_2_python_model() -> None: diff --git a/web/server/models.py b/web/server/models.py index 6bfd378eb..ddf1d5dc6 100644 --- a/web/server/models.py +++ b/web/server/models.py @@ -19,11 +19,7 @@ SnapshotId, ) from sqlmesh.utils.date import TimeLike, now_timestamp -from sqlmesh.utils.pydantic import ( - PydanticModel, - field_validator, - field_validator_v1_args, -) +from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator SUPPORTED_EXTENSIONS = {".py", ".sql", ".yaml", ".yml", ".csv"} @@ -119,11 +115,10 @@ class File(PydanticModel): content: t.Optional[str] = None model_config = pydantic.ConfigDict(validate_default=True) # type: ignore - @field_validator("extension", always=True, mode="before") - @field_validator_v1_args - def default_extension(cls, v: str, values: t.Dict[str, t.Any]) -> str: - if "name" in values: - return pathlib.Path(values["name"]).suffix + @field_validator("extension", mode="before") + def default_extension(cls, v: str, info: ValidationInfo) -> str: + if "name" in info.data: + return pathlib.Path(info.data["name"]).suffix return v From ca37e0408e51edeedf741cf55908c3904d8b66d2 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Thu, 23 Jan 2025 15:55:06 +0200 Subject: [PATCH 2/5] Address feedback related to the 
TimeColumn validation --- sqlmesh/core/model/kind.py | 80 +++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 36 deletions(-) diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index 648ca1679..b6d2c9d30 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -237,6 +237,49 @@ class TimeColumn(PydanticModel): column: exp.Expression format: t.Optional[str] = None + @classmethod + def validator(cls) -> classmethod: + def _time_column_validator(v: t.Any, info: ValidationInfo) -> TimeColumn: + dialect = get_dialect(info.data) + + if isinstance(v, exp.Tuple): + column_expr = v.expressions[0] + column = ( + exp.column(column_expr) + if isinstance(column_expr, exp.Identifier) + else column_expr + ) + format = v.expressions[1].name if len(v.expressions) > 1 else None + elif isinstance(v, exp.Expression): + column = exp.column(v) if isinstance(v, exp.Identifier) else v + format = None + elif isinstance(v, str): + column = d.parse_one(v, dialect=dialect) + column.meta.pop("sql") + format = None + elif isinstance(v, dict): + column_raw = v["column"] + column = ( + d.parse_one(column_raw, dialect=dialect) + if isinstance(column_raw, str) + else column_raw + ) + format = v.get("format") + elif isinstance(v, TimeColumn): + column = v.column + format = v.format + else: + raise ConfigError(f"Invalid time_column: '{v}'.") + + column = quote_identifiers( + normalize_identifiers(column, dialect=dialect), dialect=dialect + ) + column.meta["dialect"] = dialect + + return TimeColumn(column=column, format=format) + + return field_validator("time_column", mode="before")(_time_column_validator) + @field_validator("column", mode="before") @classmethod def _column_validator(cls, v: t.Union[str, exp.Expression]) -> exp.Expression: @@ -364,42 +407,7 @@ class IncrementalByTimeRangeKind(_IncrementalBy): time_column: TimeColumn auto_restatement_intervals: t.Optional[SQLGlotPositiveInt] = None - @field_validator("time_column", mode="before") - @classmethod - def _time_column_validator(cls, v: t.Any, values: t.Any) -> TimeColumn: - dialect = get_dialect(values) - - if isinstance(v, exp.Tuple): - column_expr = v.expressions[0] - column = ( - exp.column(column_expr) if isinstance(column_expr, exp.Identifier) else column_expr - ) - format = v.expressions[1].name if len(v.expressions) > 1 else None - elif isinstance(v, exp.Expression): - column = exp.column(v) if isinstance(v, exp.Identifier) else v - format = None - elif isinstance(v, str): - column = d.parse_one(v, dialect=dialect) - column.meta.pop("sql") - format = None - elif isinstance(v, dict): - column_raw = v["column"] - column = ( - d.parse_one(column_raw, dialect=dialect) - if isinstance(column_raw, str) - else column_raw - ) - format = v.get("format") - elif isinstance(v, TimeColumn): - column = v.column - format = v.format - else: - raise ConfigError(f"Invalid time_column: '{v}'.") - - column = quote_identifiers(normalize_identifiers(column, dialect=dialect), dialect=dialect) - column.meta["dialect"] = dialect - - return TimeColumn(column=column, format=format) + _time_column_validator = TimeColumn.validator() def to_expression( self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any From 338151114f608f44a620b06bab0f2d60f138e984 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Thu, 23 Jan 2025 16:57:32 +0200 Subject: [PATCH 3/5] Address feedback related to dict instance checks --- sqlmesh/dbt/target.py | 26 ++++++++++++---------- sqlmesh/integrations/github/cicd/config.py | 12 
 sqlmesh/dbt/target.py                      | 26 ++++++++++++----------
 sqlmesh/integrations/github/cicd/config.py | 12 +++++-----
 2 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/sqlmesh/dbt/target.py b/sqlmesh/dbt/target.py
index 86a57dfdd..74e8fcbd6 100644
--- a/sqlmesh/dbt/target.py
+++ b/sqlmesh/dbt/target.py
@@ -257,14 +257,13 @@ class SnowflakeConfig(TargetConfig):
     @model_validator(mode="before")
     @classmethod
     def validate_authentication(cls, data: t.Any) -> t.Any:
-        if isinstance(data, dict):
-            if (
-                data.get("password")
-                or data.get("authenticator")
-                or data.get("private_key")
-                or data.get("private_key_path")
-            ):
-                return data
+        if not isinstance(data, dict) or (
+            data.get("password")
+            or data.get("authenticator")
+            or data.get("private_key")
+            or data.get("private_key_path")
+        ):
+            return data
 
         raise ConfigError("No supported Snowflake authentication method found in target profile.")
 
@@ -403,10 +402,13 @@ class RedshiftConfig(TargetConfig):
     @model_validator(mode="before")
     @classmethod
     def validate_database(cls, data: t.Any) -> t.Any:
-        if isinstance(data, dict):
-            data["database"] = data.get("database") or data.get("dbname")
-            if not data["database"]:
-                raise ConfigError("Either database or dbname must be set")
+        if not isinstance(data, dict):
+            return data
+
+        data["database"] = data.get("database") or data.get("dbname")
+        if not data["database"]:
+            raise ConfigError("Either database or dbname must be set")
+
         return data
 
     def default_incremental_strategy(self, kind: IncrementalKind) -> str:
diff --git a/sqlmesh/integrations/github/cicd/config.py b/sqlmesh/integrations/github/cicd/config.py
index 04ce2337b..99f98a63a 100644
--- a/sqlmesh/integrations/github/cicd/config.py
+++ b/sqlmesh/integrations/github/cicd/config.py
@@ -32,11 +32,13 @@ class GithubCICDBotConfig(BaseConfig):
     @model_validator(mode="before")
     @classmethod
     def _validate(cls, data: t.Any) -> t.Any:
-        if isinstance(data, dict):
-            if data.get("enable_deploy_command") and not data.get("merge_method"):
-                raise ValueError("merge_method must be set if enable_deploy_command is True")
-            if data.get("command_namespace") and not data.get("enable_deploy_command"):
-                raise ValueError("enable_deploy_command must be set if command_namespace is set")
+        if not isinstance(data, dict):
+            return data
+
+        if data.get("enable_deploy_command") and not data.get("merge_method"):
+            raise ValueError("merge_method must be set if enable_deploy_command is True")
+        if data.get("command_namespace") and not data.get("enable_deploy_command"):
+            raise ValueError("enable_deploy_command must be set if command_namespace is set")
 
         return data
 

From 8fc85c22303e31619477d30ad5563b5be3e6ef99 Mon Sep 17 00:00:00 2001
From: George Sittas
Date: Thu, 23 Jan 2025 18:02:35 +0200
Subject: [PATCH 4/5] Temporarily comment out main branch filter in CI to test cloud engines

---
 .circleci/continue_config.yml | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml
index 201a7c3a6..ebacc9500 100644
--- a/.circleci/continue_config.yml
+++ b/.circleci/continue_config.yml
@@ -313,10 +313,10 @@ workflows:
           - bigquery
           - clickhouse-cloud
           - athena
-          filters:
-            branches:
-              only:
-                - main
+          # filters:
+          #   branches:
+          #     only:
+          #       - main
     - trigger_private_tests:
         requires:
          - style_and_slow_tests

From 5303811ded803404579f0de9e9e326a73b81e259 Mon Sep 17 00:00:00 2001
From: George Sittas
Date: Thu, 23 Jan 2025 19:19:50 +0200
Subject: [PATCH 5/5] Revert CI comment, add test for standalone audit non-blocking validation

---
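Note: the new test asserts AuditConfigError directly rather than
pydantic.ValidationError: pydantic v2 only converts ValueError and AssertionError
raised inside validators into ValidationError, so other exception types propagate
out of the constructor unchanged. A minimal sketch of the behavior under test,
with ExampleAudit and ExampleError as illustrative stand-ins for StandaloneAudit
and AuditConfigError:

    import pydantic
    import pytest
    from pydantic import model_validator


    class ExampleError(Exception):
        pass


    class ExampleAudit(pydantic.BaseModel):
        name: str
        blocking: bool = False

        @model_validator(mode="after")
        def _non_blocking(self) -> "ExampleAudit":
            if self.blocking:
                # Not a ValueError, so it escapes validation unwrapped.
                raise ExampleError(f"Standalone audits cannot be blocking: '{self.name}'.")
            return self


    def test_blocking_rejected() -> None:
        with pytest.raises(ExampleError) as ex:
            ExampleAudit(name="test_audit", blocking=True)
        assert "cannot be blocking" in str(ex.value)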
 .circleci/continue_config.yml | 8 ++++----
 tests/core/test_audit.py      | 5 +++++
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml
index ebacc9500..201a7c3a6 100644
--- a/.circleci/continue_config.yml
+++ b/.circleci/continue_config.yml
@@ -313,10 +313,10 @@ workflows:
           - bigquery
           - clickhouse-cloud
           - athena
-          # filters:
-          #   branches:
-          #     only:
-          #       - main
+          filters:
+            branches:
+              only:
+                - main
     - trigger_private_tests:
         requires:
          - style_and_slow_tests
diff --git a/tests/core/test_audit.py b/tests/core/test_audit.py
index 90ee1603d..acb0e6676 100644
--- a/tests/core/test_audit.py
+++ b/tests/core/test_audit.py
@@ -678,6 +678,11 @@ def test_standalone_audit(model: Model, assert_exp_eq):
         rendered_query, """SELECT * FROM "db"."test_model" AS "test_model" WHERE "col" IS NULL"""
     )
 
+    with pytest.raises(AuditConfigError) as ex:
+        StandaloneAudit(name="test_audit", query=parse_one("SELECT 1"), blocking=True)
+
+    assert "Standalone audits cannot be blocking: 'test_audit'." in str(ex.value)
+
 
 def test_render_definition():
     expressions = parse(