diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 3efc81fa1..b0ab92151 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -61,6 +61,10 @@ class JinjaStatement(Jinja): pass +class VirtualStatement(exp.Expression): + pass + + class ModelKind(exp.Expression): arg_types = {"this": True, "expressions": False} @@ -749,6 +753,8 @@ def _is_command_statement(command: str, tokens: t.List[Token], pos: int) -> bool JINJA_QUERY_BEGIN = "JINJA_QUERY_BEGIN" JINJA_STATEMENT_BEGIN = "JINJA_STATEMENT_BEGIN" JINJA_END = "JINJA_END" +ON_VIRTUAL_UPDATE_BEGIN = "ON_VIRTUAL_UPDATE_BEGIN" +ON_VIRTUAL_UPDATE_END = "ON_VIRTUAL_UPDATE_END" def _is_jinja_statement_begin(tokens: t.List[Token], pos: int) -> bool: @@ -771,10 +777,24 @@ def jinja_statement(statement: str) -> JinjaStatement: return JinjaStatement(this=exp.Literal.string(statement.strip())) +def _is_virtual_statement_begin(tokens: t.List[Token], pos: int) -> bool: + return _is_command_statement(ON_VIRTUAL_UPDATE_BEGIN, tokens, pos) + + +def _is_virtual_statement_end(tokens: t.List[Token], pos: int) -> bool: + return _is_command_statement(ON_VIRTUAL_UPDATE_END, tokens, pos) + + +def virtual_statement(statement: exp.Expression) -> VirtualStatement: + return VirtualStatement(this=statement) + + class ChunkType(Enum): JINJA_QUERY = auto() JINJA_STATEMENT = auto() SQL = auto() + VIRTUAL_STATEMENT = auto() + VIRTUAL_JINJA_STATEMENT = auto() def parse_one( @@ -814,9 +834,14 @@ def parse( total = len(tokens) pos = 0 + virtual = False while pos < total: token = tokens[pos] - if _is_jinja_end(tokens, pos) or ( + if _is_virtual_statement_end(tokens, pos): + pos += 2 + virtual = False + chunks.append(([], ChunkType.SQL)) + elif _is_jinja_end(tokens, pos) or ( chunks[-1][1] == ChunkType.SQL and token.token_type == TokenType.SEMICOLON and pos < total - 1 @@ -827,13 +852,32 @@ def parse( # Jinja end statement chunks[-1][0].append(token) pos += 2 - chunks.append(([], ChunkType.SQL)) + if virtual and 
tokens[pos] != ON_VIRTUAL_UPDATE_END: + # This is required for nested Jinja statements that precede + # SQL statements within an ON_VIRTUAL_UPDATE block + chunks.append( + ( + [Token(TokenType.VAR, text=ON_VIRTUAL_UPDATE_BEGIN)], + ChunkType.VIRTUAL_STATEMENT, + ) + ) + else: + chunks.append(([], ChunkType.SQL)) elif _is_jinja_query_begin(tokens, pos): chunks.append(([token], ChunkType.JINJA_QUERY)) pos += 2 elif _is_jinja_statement_begin(tokens, pos): - chunks.append(([token], ChunkType.JINJA_STATEMENT)) + chunks.append( + ( + [token], + ChunkType.VIRTUAL_JINJA_STATEMENT if virtual else ChunkType.JINJA_STATEMENT, + ) + ) + pos += 2 + elif _is_virtual_statement_begin(tokens, pos): + chunks.append(([token], ChunkType.VIRTUAL_STATEMENT)) pos += 2 + virtual = True else: chunks[-1][0].append(token) pos += 1 @@ -850,13 +894,23 @@ def parse( if expression: expression.meta["sql"] = parser._find_sql(chunk[0], chunk[-1]) expressions.append(expression) + elif chunk_type == ChunkType.VIRTUAL_STATEMENT: + sql_chunk = chunk[1:-1] + for expression in parser.parse(sql_chunk, sql): + if expression: + expression.meta["sql"] = expression.sql(dialect=dialect) + expressions.append(virtual_statement(expression)) else: start, *_, end = chunk segment = sql[start.end + 2 : end.start - 1] factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement expression = factory(segment.strip()) expression.meta["sql"] = sql[start.start : end.end + 1] - expressions.append(expression) + expressions.append( + virtual_statement(expression) + if chunk_type == ChunkType.VIRTUAL_JINJA_STATEMENT + else expression + ) return expressions diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index 7dca7661e..81afd22b4 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -287,6 +287,7 @@ def depends_on(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> t.Optional[ "expressions_", "pre_statements_", "post_statements_", + "on_virtual_update_", 
"unique_key", mode="before", check_fields=False, diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index cff43e974..61313cb98 100644 --- a/sqlmesh/core/model/decorator.py +++ b/sqlmesh/core/model/decorator.py @@ -135,7 +135,7 @@ def model( **self.kwargs, } - for key in ("pre_statements", "post_statements"): + for key in ("pre_statements", "post_statements", "on_virtual_update"): statements = common_kwargs.get(key) if statements: common_kwargs[key] = [ diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index b00fc97aa..6a793a7a7 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -133,6 +133,9 @@ class _Model(ModelMeta, frozen=True): post_statements_: t.Optional[t.List[exp.Expression]] = Field( default=None, alias="post_statements" ) + on_virtual_update_: t.Optional[t.List[exp.Expression]] = Field( + default=None, alias="on_virtual_update" + ) _expressions_validator = expression_validator @@ -499,10 +502,18 @@ def pre_statements(self) -> t.List[exp.Expression]: def post_statements(self) -> t.List[exp.Expression]: return self.post_statements_ or [] + @property + def on_virtual_update(self) -> t.List[exp.Expression]: + return self.on_virtual_update_ or [] + @property def macro_definitions(self) -> t.List[d.MacroDef]: """All macro definitions from the list of expressions.""" - return [s for s in self.pre_statements + self.post_statements if isinstance(s, d.MacroDef)] + return [ + s + for s in self.pre_statements + self.post_statements + self.on_virtual_update + if isinstance(s, d.MacroDef) + ] def _render_statements( self, @@ -891,7 +902,7 @@ def _data_hash_values(self) -> t.List[str]: data.append(key) data.append(gen(value)) - for statement in (*self.pre_statements, *self.post_statements): + for statement in (*self.pre_statements, *self.post_statements, *self.on_virtual_update): statement_exprs: t.List[exp.Expression] = [] if not isinstance(statement, d.MacroDef): rendered = 
self._statement_renderer(statement).render() @@ -984,7 +995,7 @@ def _additional_metadata(self) -> t.List[str]: if metadata_only_macros: additional_metadata.append(str(metadata_only_macros)) - for statement in (*self.pre_statements, *self.post_statements): + for statement in (*self.pre_statements, *self.post_statements, *self.on_virtual_update): if self._is_metadata_statement(statement): additional_metadata.append(gen(statement)) @@ -1056,6 +1067,7 @@ class SqlModel(_Model): query: The main query representing the model. pre_statements: The list of SQL statements that precede the model's query. post_statements: The list of SQL statements that follow after the model's query. + on_virtual_update: The list of SQL statements to be executed after virtual update. """ query: t.Union[exp.Query, d.JinjaQuery, d.MacroFunc] @@ -1117,6 +1129,7 @@ def render_definition( result.extend(self.pre_statements) result.append(self.query) result.extend(self.post_statements) + result.extend(self.on_virtual_update) return result @property @@ -1680,7 +1693,7 @@ def load_sql_based_model( rendered_meta = rendered_meta_exprs[0] # Extract the query and any pre/post statements - query_or_seed_insert, pre_statements, post_statements, inline_audits = ( + query_or_seed_insert, pre_statements, post_statements, on_virtual_update, inline_audits = ( _split_sql_model_statements(expressions[1:], path, dialect=dialect) ) @@ -1717,6 +1730,7 @@ def load_sql_based_model( common_kwargs = dict( pre_statements=pre_statements, post_statements=post_statements, + on_virtual_update=on_virtual_update, defaults=defaults, path=path, module_path=module_path, @@ -1968,6 +1982,8 @@ def _create_model( statements.append(kwargs["query"]) if "post_statements" in kwargs: statements.extend(kwargs["post_statements"]) + if "on_virtual_update" in kwargs: + statements.extend(kwargs["on_virtual_update"]) jinja_macro_references, used_variables = extract_macro_references_and_variables( *(gen(e) for e in statements) @@ -2057,6 +2073,7 
@@ def _split_sql_model_statements( t.Optional[exp.Expression], t.List[exp.Expression], t.List[exp.Expression], + t.List[exp.Expression], UniqueKeyDict[str, ModelAudit], ]: """Extracts the SELECT query from a sequence of expressions. @@ -2075,6 +2092,7 @@ def _split_sql_model_statements( query_positions = [] sql_statements = [] + on_virtual_update = [] inline_audits: UniqueKeyDict[str, ModelAudit] = UniqueKeyDict("inline_audits") idx = 0 @@ -2086,7 +2104,9 @@ def _split_sql_model_statements( loaded_audit = load_audit([expr, expressions[idx + 1]], dialect=dialect) assert isinstance(loaded_audit, ModelAudit) inline_audits[loaded_audit.name] = loaded_audit - idx += 2 + idx += 1 + elif isinstance(expr, d.VirtualStatement): + on_virtual_update.append(expr.this) else: if ( isinstance(expr, (exp.Query, d.JinjaQuery)) @@ -2098,16 +2118,16 @@ def _split_sql_model_statements( ): query_positions.append((expr, idx)) sql_statements.append(expr) - idx += 1 + idx += 1 if not query_positions: - return None, sql_statements, [], inline_audits + return None, sql_statements, [], on_virtual_update, inline_audits elif len(query_positions) > 1: raise_config_error("Only one SELECT query is allowed per model", path) query, pos = query_positions[0] - return query, sql_statements[:pos], sql_statements[pos + 1 :], inline_audits + return query, sql_statements[:pos], sql_statements[pos + 1 :], on_virtual_update, inline_audits def _resolve_session_properties( diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index ac61936aa..d77da1f18 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -43,6 +43,7 @@ from sqlmesh.schedulers.airflow.client import AirflowClient, BaseAirflowClient from sqlmesh.schedulers.airflow.mwaa_client import MWAAClient from sqlmesh.utils.errors import PlanError, SQLMeshError +from sqlmesh.utils.date import now logger = logging.getLogger(__name__) @@ -309,9 +310,10 @@ def _update_views( completed = False try: + 
added_snapshots = [snapshots[s.snapshot_id] for s in promotion_result.added] self._promote_snapshots( plan, - [snapshots[s.snapshot_id] for s in promotion_result.added], + added_snapshots, environment.naming_info, deployability_index=deployability_index, on_complete=lambda s: self.console.update_promotion_progress(s, True), @@ -323,6 +325,17 @@ def _update_views( promotion_result.removed_environment_naming_info, on_complete=lambda s: self.console.update_promotion_progress(s, False), ) + + if promoted_snapshots := [ + s for s in added_snapshots if s.is_model and not s.is_symbolic + ]: + self._virtual_statements( + plan, + promoted_snapshots, + snapshots, + deployability_index, + ) + self.state_sync.finalize(environment) completed = True finally: @@ -354,6 +367,24 @@ def _demote_snapshots( target_snapshots, environment_naming_info, on_complete=on_complete ) + def _virtual_statements( + self, + plan: EvaluatablePlan, + target_snapshots: t.Iterable[Snapshot], + snapshots: t.Dict[SnapshotId, Snapshot], + deployability_index: t.Optional[DeployabilityIndex] = None, + ) -> None: + self.snapshot_evaluator._execute_virtual_statements( + target_snapshots, + snapshots, + plan.start, + plan.end, + plan.execution_time or now(), + plan.environment.naming_info, + self.default_catalog, + deployability_index, + ) + def _restate(self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot]) -> None: if not plan.restatements: return diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 3904409e2..6f8990986 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -1591,6 +1591,21 @@ def to_table_mapping( } +def to_view_mapping( + snapshots: t.Iterable[Snapshot], + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str] = None, + dialect: t.Optional[str] = None, +) -> t.Dict[str, str]: + return { + snapshot.name: snapshot.display_name( + environment_naming_info, 
def _execute_virtual_statements(
    self,
    target_snapshots: t.Iterable[Snapshot],
    snapshots: t.Dict[SnapshotId, Snapshot],
    start: TimeLike,
    end: TimeLike,
    execution_time: TimeLike,
    environment_naming_info: EnvironmentNamingInfo,
    default_catalog: t.Optional[str] = None,
    deployability_index: t.Optional[DeployabilityIndex] = None,
) -> None:
    """Render and execute the ``on_virtual_update`` statements of the target snapshots.

    Args:
        target_snapshots: The snapshots whose virtual statements should run.
        snapshots: All known snapshots keyed by snapshot ID; used to build the
            view mapping and to look up each target's parents.
        start: Start of the interval passed to the statement renderer.
        end: End of the interval passed to the statement renderer.
        execution_time: Execution time passed to the statement renderer.
        environment_naming_info: Naming info used to derive view names.
        default_catalog: Default catalog used when building view names.
        deployability_index: Forwarded unchanged to the statement renderer.

    Raises:
        KeyError: If a target with virtual statements lists a parent that is
            missing from ``snapshots``.
    """
    # Resolve model names to their qualified view names once for all targets.
    table_mapping = to_view_mapping(
        snapshots.values(),
        environment_naming_info,
        default_catalog=default_catalog,
        dialect=self.adapter.dialect,
    )

    for snapshot in target_snapshots:
        virtual_statements = snapshot.model.on_virtual_update
        if not virtual_statements:
            # Nothing to run: skip the adapter lookup and the parent resolution.
            continue

        adapter = self._get_adapter(snapshot.model_gateway)

        # Name-keyed view of this snapshot and its direct parents. This is the
        # mapping handed to the renderer; previously it was built but the
        # SnapshotId-keyed ``snapshots`` dict was passed instead.
        snapshot_deps = {
            snapshots[parent_id].name: snapshots[parent_id] for parent_id in snapshot.parents
        }
        snapshot_deps[snapshot.name] = snapshot

        adapter.execute(
            snapshot.model._render_statements(
                virtual_statements,
                start=start,
                end=end,
                execution_time=execution_time,
                snapshots=snapshot_deps,
                deployability_index=deployability_index,
                engine_adapter=adapter,
                table_mapping=table_mapping,
            )
        )
d.parse( + """ + MODEL ( + name parent, + ); + + SELECT 1 from id; + + ON_VIRTUAL_UPDATE_BEGIN; + JINJA_STATEMENT_BEGIN; + GRANT SELECT ON VIEW {{this_model}} TO ROLE admin; + JINJA_END; + ON_VIRTUAL_UPDATE_END; + + """ + ) + + model = load_sql_based_model(expressions) + parent = load_sql_based_model(parent_expressions) + + parent_snapshot = make_snapshot(parent) + parent_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + version = parent_snapshot.version + + expected_virtual_statements = [ + *d.parse("CREATE OR REPLACE VIEW test_view FROM demo_db.table;"), + *d.parse("GRANT SELECT ON VIEW @this_model TO ROLE owner_name;"), + *d.parse( + "JINJA_STATEMENT_BEGIN; GRANT SELECT ON VIEW {{this_model}} TO ROLE admin; JINJA_END;" + ), + *d.parse( + "GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name;" + ), + *d.parse("@resolve_parent_name('parent')"), + *d.parse( + "GRANT SELECT ON VIEW demo_db.table /* sqlglot.meta replace=false */ TO ROLE admin;" + ), + ] + + assert model.on_virtual_update == expected_virtual_statements + + assert parent.on_virtual_update == [ + *d.parse( + "JINJA_STATEMENT_BEGIN; GRANT SELECT ON VIEW {{this_model}} TO ROLE admin; JINJA_END;" + ) + ] + + table_mapping = {'"demo_db"."table"': "demo_db__dev.table"} + snapshots = {'"parent"': parent_snapshot} + + rendered_statements = model._render_statements( + model.on_virtual_update, snapshots=snapshots, table_mapping=table_mapping + ) + + assert len(rendered_statements) == 6 + assert ( + rendered_statements[0].sql() + == 'CREATE OR REPLACE VIEW "test_view" AS SELECT * FROM "demo_db__dev"."table" AS "table" /* demo_db.table */' + ) + assert ( + rendered_statements[1].sql() + == 'GRANT SELECT ON VIEW "demo_db__dev"."table" /* demo_db.table */ TO ROLE "owner_name"' + ) + assert ( + rendered_statements[3].sql() + == "GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name" + ) + assert rendered_statements[4].sql() == 
f'"sqlmesh__default"."parent__{version}"' + + # When replace=false the table should remain as is + assert ( + rendered_statements[5].sql() + == 'GRANT SELECT ON VIEW "demo_db"."table" /* sqlglot.meta replace=false */ TO ROLE "admin"' + ) + + rendered_parent_statements = model._render_statements( + parent.on_virtual_update, snapshots=snapshots, table_mapping=table_mapping + ) + assert ( + rendered_statements[2].sql() + == rendered_parent_statements[0].sql() + == 'GRANT SELECT ON VIEW "demo_db__dev"."table" /* demo_db.table */ TO ROLE "admin"' + ) + + +def test_python_model_on_virtual_update(): + macros = """ + {% macro index_name(v) %}{{ v }}{% endmacro %} + """ + + jinja_macros = JinjaMacroRegistry() + jinja_macros.add_macros(MacroExtractor().extract(macros)) + + @model( + "db.test_model", + kind="full", + columns={"id": "string", "name": "string"}, + on_virtual_update=[ + "JINJA_STATEMENT_BEGIN;\nCREATE INDEX {{index_name('id_index')}} ON db.test_model(id);\nJINJA_END;", + parse_one("GRANT SELECT ON VIEW @this_model TO ROLE dev_role;"), + "GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE db TO ROLE dev_role;", + ], + ) + def model_with_virtual_statements(context, **kwargs): + return pd.DataFrame( + [ + { + "id": context.var("1"), + "name": context.var("var"), + } + ] + ) + + python_model = model.get_registry()["db.test_model"].model( + module_path=Path("."), path=Path("."), dialect="duckdb", jinja_macros=jinja_macros + ) + + assert len(jinja_macros.root_macros) == 1 + assert len(python_model.jinja_macros.root_macros) == 1 + assert "index_name" in python_model.jinja_macros.root_macros + assert len(python_model.on_virtual_update) == 3 + + rendered_statements = python_model._render_statements( + python_model.on_virtual_update, table_mapping={'"db"."test_model"': "db.test_model"} + ) + + assert ( + rendered_statements[0].sql() + == 'CREATE INDEX "id_index" ON "db"."test_model" /* db.test_model */("id" NULLS LAST)' + ) + assert ( + rendered_statements[1].sql() + == 
'GRANT SELECT ON VIEW "db"."test_model" /* db.test_model */ TO ROLE "dev_role"' + ) + assert ( + rendered_statements[2].sql() + == "GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE db TO ROLE dev_role" + )