Skip to content

Commit

Permalink
Refactor SQL resolver dispatch (#7457)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen authored Jun 13, 2024
1 parent a2b0c3b commit fb32c83
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 22 deletions.
36 changes: 19 additions & 17 deletions edb/pgsql/resolver/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,22 @@

@functools.singledispatch
def _resolve(
ir: pgast.Base, *, ctx: context.ResolverContextLevel
expr: pgast.Base, *, ctx: context.ResolverContextLevel
) -> pgast.Base:
raise ValueError(f'no SQL resolve handler for {ir.__class__}')
raise ValueError(f'no SQL resolve handler for {expr.__class__}')


def resolve(ir: Base_T, *, ctx: context.ResolverContextLevel) -> Base_T:
res = _resolve(ir, ctx=ctx)
return typing.cast(Base_T, res.replace(span=ir.span))
def resolve(expr: Base_T, *, ctx: context.ResolverContextLevel) -> Base_T:
res = _resolve(expr, ctx=ctx)
return typing.cast(Base_T, res.replace(span=expr.span))


def resolve_opt(
ir: typing.Optional[Base_T], *, ctx: context.ResolverContextLevel
node: typing.Optional[Base_T], *, ctx: context.ResolverContextLevel
) -> typing.Optional[Base_T]:
if not ir:
if not node:
return None
return resolve(ir, ctx=ctx)
return resolve(node, ctx=ctx)


def resolve_list(
Expand All @@ -66,23 +66,25 @@ def resolve_opt_list(
return resolve_list(exprs, ctx=ctx)


@functools.singledispatch
def _resolve_relation(
ir: pgast.BaseRelation, *, ctx: context.ResolverContextLevel
def resolve_relation(
rel: pgast.BaseRelation, *, ctx: context.ResolverContextLevel
) -> typing.Tuple[pgast.BaseRelation, context.Table]:
raise ValueError(f'no SQL resolve handler for {ir.__class__}')
rel, tab = _resolve_relation(rel, ctx=ctx)
return rel.replace(span=rel.span), tab


def resolve_relation(
ir: BaseRelation_T, *, ctx: context.ResolverContextLevel
) -> typing.Tuple[BaseRelation_T, context.Table]:
res, tab = _resolve_relation(ir, ctx=ctx)
return typing.cast(BaseRelation_T, res.replace(span=ir.span)), tab
@functools.singledispatch
def _resolve_relation(
rel: pgast.BaseRelation, *, ctx: context.ResolverContextLevel
) -> typing.Tuple[pgast.BaseRelation, context.Table]:
raise ValueError(f'no SQL resolve handler for {rel.__class__}')


@_resolve.register
def _resolve_BaseRelation(
rel: pgast.BaseRelation, *, ctx: context.ResolverContextLevel
) -> pgast.BaseRelation:
# use _resolve_BaseRelation in normal _resolve dispatch

rel, _ = resolve_relation(rel, ctx=ctx)
return rel
6 changes: 4 additions & 2 deletions edb/pgsql/resolver/range_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def _resolve_RangeSubselect(
)

node = pgast.RangeSubselect(
subquery=subquery, alias=alias, lateral=range_var.lateral
subquery=cast(pgast.Query, subquery),
alias=alias,
lateral=range_var.lateral,
)
return node, result

Expand Down Expand Up @@ -263,7 +265,7 @@ def resolve_CommonTableExpr(
name=cte.name,
span=cte.span,
aliascolnames=reference_as,
query=query,
query=cast(pgast.Query, query),
recursive=cte.recursive,
materialized=cte.materialized,
)
Expand Down
6 changes: 3 additions & 3 deletions edb/pgsql/resolver/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""SQL resolver that compiles public SQL to internal SQL which is executable
in our internal Postgres instance."""

from typing import Optional, Tuple, List
from typing import Optional, Tuple, List, cast

from edb import errors
from edb.server.pgcon import errors as pgerror
Expand Down Expand Up @@ -87,8 +87,8 @@ def resolve_SelectStmt(
)

relation = pgast.SelectStmt(
larg=larg,
rarg=rarg,
larg=cast(pgast.Query, larg),
rarg=cast(pgast.Query, rarg),
op=stmt.op,
all=stmt.all,
)
Expand Down

0 comments on commit fb32c83

Please sign in to comment.