From fb32c83717e43d1e188395c0879d522927d5fd23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Thu, 13 Jun 2024 22:07:22 +0200 Subject: [PATCH] Refactor SQL resolver dispatch (#7457) --- edb/pgsql/resolver/dispatch.py | 36 +++++++++++++++++---------------- edb/pgsql/resolver/range_var.py | 6 ++++-- edb/pgsql/resolver/relation.py | 6 +++--- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/edb/pgsql/resolver/dispatch.py b/edb/pgsql/resolver/dispatch.py index 2e606e32e42..f5c73cc94cc 100644 --- a/edb/pgsql/resolver/dispatch.py +++ b/edb/pgsql/resolver/dispatch.py @@ -32,22 +32,22 @@ @functools.singledispatch def _resolve( - ir: pgast.Base, *, ctx: context.ResolverContextLevel + expr: pgast.Base, *, ctx: context.ResolverContextLevel ) -> pgast.Base: - raise ValueError(f'no SQL resolve handler for {ir.__class__}') + raise ValueError(f'no SQL resolve handler for {expr.__class__}') -def resolve(ir: Base_T, *, ctx: context.ResolverContextLevel) -> Base_T: - res = _resolve(ir, ctx=ctx) - return typing.cast(Base_T, res.replace(span=ir.span)) +def resolve(expr: Base_T, *, ctx: context.ResolverContextLevel) -> Base_T: + res = _resolve(expr, ctx=ctx) + return typing.cast(Base_T, res.replace(span=expr.span)) def resolve_opt( - ir: typing.Optional[Base_T], *, ctx: context.ResolverContextLevel + node: typing.Optional[Base_T], *, ctx: context.ResolverContextLevel ) -> typing.Optional[Base_T]: - if not ir: + if not node: return None - return resolve(ir, ctx=ctx) + return resolve(node, ctx=ctx) def resolve_list( @@ -66,23 +66,25 @@ def resolve_opt_list( return resolve_list(exprs, ctx=ctx) -@functools.singledispatch -def _resolve_relation( - ir: pgast.BaseRelation, *, ctx: context.ResolverContextLevel +def resolve_relation( + rel: pgast.BaseRelation, *, ctx: context.ResolverContextLevel ) -> typing.Tuple[pgast.BaseRelation, context.Table]: - raise ValueError(f'no SQL resolve handler for {ir.__class__}') + rel, tab = _resolve_relation(rel, ctx=ctx) + return rel.replace(span=rel.span), tab -def resolve_relation( - ir: BaseRelation_T, *, ctx: context.ResolverContextLevel -) -> typing.Tuple[BaseRelation_T, context.Table]: - res, tab = _resolve_relation(ir, ctx=ctx) - return typing.cast(BaseRelation_T, res.replace(span=ir.span)), tab +@functools.singledispatch +def _resolve_relation( + rel: pgast.BaseRelation, *, ctx: context.ResolverContextLevel +) -> typing.Tuple[pgast.BaseRelation, context.Table]: + raise ValueError(f'no SQL resolve handler for {rel.__class__}') @_resolve.register def _resolve_BaseRelation( rel: pgast.BaseRelation, *, ctx: context.ResolverContextLevel ) -> pgast.BaseRelation: + # use _resolve_BaseRelation in normal _resolve dispatch + rel, _ = resolve_relation(rel, ctx=ctx) return rel diff --git a/edb/pgsql/resolver/range_var.py b/edb/pgsql/resolver/range_var.py index e040064f95b..74f0fdc3a23 100644 --- a/edb/pgsql/resolver/range_var.py +++ b/edb/pgsql/resolver/range_var.py @@ -151,7 +151,9 @@ def _resolve_RangeSubselect( ) node = pgast.RangeSubselect( - subquery=subquery, alias=alias, lateral=range_var.lateral + subquery=cast(pgast.Query, subquery), + alias=alias, + lateral=range_var.lateral, ) return node, result @@ -263,7 +265,7 @@ def resolve_CommonTableExpr( name=cte.name, span=cte.span, aliascolnames=reference_as, - query=query, + query=cast(pgast.Query, query), recursive=cte.recursive, materialized=cte.materialized, ) diff --git a/edb/pgsql/resolver/relation.py b/edb/pgsql/resolver/relation.py index fcddb875b4c..0188646dbb9 100644 --- a/edb/pgsql/resolver/relation.py +++ b/edb/pgsql/resolver/relation.py @@ -19,7 +19,7 @@ """SQL resolver that compiles public SQL to internal SQL which is executable in our internal Postgres instance.""" -from typing import Optional, Tuple, List +from typing import Optional, Tuple, List, cast from edb import errors from edb.server.pgcon import errors as pgerror @@ -87,8 +87,8 @@ def resolve_SelectStmt( ) relation = pgast.SelectStmt( - larg=larg, - rarg=rarg, + larg=cast(pgast.Query, larg), + rarg=cast(pgast.Query, rarg), op=stmt.op, all=stmt.all, )