Skip to content

Commit

Permalink
finish the resolver part of locking clauses
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen committed Nov 21, 2024
1 parent f2c6e32 commit fab69b5
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 77 deletions.
8 changes: 4 additions & 4 deletions edb/pgsql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,10 +913,10 @@ class SortBy(ImmutableBase):


class LockClauseStrength(enum.StrEnum):
FORKEYSHARE = "KEY SHARE"
FORSHARE = "SHARE"
FORNOKEYUPDATE = "NO KEY SHARE"
FORUPDATE = "UPDATE"
UPDATE = "UPDATE"
NO_KEY_UPDATE = "NO KEY UPDATE"
SHARE = "SHARE"
KEY_SHARE = "KEY SHARE"


class LockWaitPolicy(enum.StrEnum):
Expand Down
35 changes: 20 additions & 15 deletions edb/pgsql/parser/ast_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ def _list(
) -> List[U]:
if unwrap is not None:
return [
mapper(builder(_unwrap(n, unwrap), ctx))
for n in node.get(name, [])
mapper(builder(_unwrap(n, unwrap), ctx)) for n in node.get(name, [])
]
else:
return [mapper(builder(n, ctx)) for n in node.get(name, [])]
Expand All @@ -118,7 +117,8 @@ def _maybe_list(
) -> Optional[List[U]]:
return (
_list(node, ctx, name, builder, mapper, unwrap=unwrap)
if name in node else None
if name in node
else None
)


Expand Down Expand Up @@ -287,7 +287,8 @@ def _build_select_stmt(n: Node, c: Context) -> pgast.SelectStmt:
limit_offset=_maybe(n, c, "limitOffset", _build_base_expr),
limit_count=_maybe(n, c, "limitCount", _build_base_expr),
locking_clause=_maybe_list(
n, c, "lockingClause", _build_locking_clause),
n, c, "lockingClause", _build_locking_clause
),
op=op,
all=n["all"] if "all" in n else False,
larg=_maybe(n, c, "larg", _build_select_stmt),
Expand Down Expand Up @@ -893,9 +894,7 @@ def _build_const(n: Node, c: Context) -> pgast.BaseConstant:
return pgast.NumericConstant(val=_unwrap_float(n), span=span)

if "String" in n or "sval" in n:
return pgast.StringConstant(
val=_unwrap_string(n), span=span
)
return pgast.StringConstant(val=_unwrap_string(n), span=span)
if "BitString" in n or "bsval" in n:
n = _unwrap(n, 'BitString')
n = _unwrap(n, 'str')
Expand Down Expand Up @@ -1179,24 +1178,30 @@ def _build_sort_by(n: Node, c: Context) -> pgast.SortBy:
def _build_locking_clause(n: Node, c: Context) -> pgast.LockingClause:
n = _unwrap(n, "LockingClause")

lcs = n["strength"].removeprefix("LCS_")
match n["strength"]:
case "LCS_FORUPDATE":
strength = pgast.LockClauseStrength.UPDATE
case "LCS_FORNOKEYUPDATE":
strength = pgast.LockClauseStrength.NO_KEY_UPDATE
case "LCS_FORSHARE":
strength = pgast.LockClauseStrength.SHARE
case "LCS_FORKEYSHARE":
strength = pgast.LockClauseStrength.KEY_SHARE
case lcs:
raise PSqlUnsupportedError(n, f"FOR lock strength: {lcs}")

strength = getattr(pgast.LockClauseStrength, lcs, None)
if strength is None:
raise PSqlUnsupportedError(f"unrecognized FOR lock strength: {lcs}")

print(n)
if pol := n.get("waitPolicy"):
wait_policy = getattr(pgast.LockWaitPolicy, pol, None)
if wait_policy is None:
raise PSqlUnsupportedError(f"unrecognized FOR wait policy: {pol}")
raise PSqlUnsupportedError(n, f"FOR wait policy: {pol}")
else:
wait_policy = None

return pgast.LockingClause(
strength=strength,
locked_rels=_maybe_list(
n, c, "lockedRels", _build_rel_range_var, unwrap="RangeVar"),
n, c, "lockedRels", _build_rel_range_var, unwrap="RangeVar"
),
wait_policy=wait_policy,
)

Expand Down
5 changes: 5 additions & 0 deletions edb/pgsql/resolver/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ class Table:
# than columns of input rel vars (tables).
precedence: int = 0

# True when this relation is compiled to a direct reference to the
# underlying table, without any views or CTEs.
# Is the condition for usage of locking clauses.
is_direct_relation: bool = False

def __str__(self) -> str:
columns = ', '.join(str(c) for c in self.columns)
alias = f'{self.alias} = ' if self.alias else ''
Expand Down
36 changes: 21 additions & 15 deletions edb/pgsql/resolver/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import uuid

from edb import errors
from edb.common.typeutils import not_none

from edb.pgsql import ast as pgast
from edb.pgsql import common
Expand Down Expand Up @@ -509,24 +508,31 @@ def resolve_LockingClause(
*,
ctx: Context,
) -> pgast.LockingClause:
lrels = expr.locked_rels
if lrels is not None:
resolved_lrels = []
for rvar in lrels:
ltable = _lookup_table(not_none(rvar.relation.name), ctx=ctx)
resolved_lrels.append(
pgast.RelRangeVar(
relation=pgast.Relation(
name=ltable.reference_as,
)
)
)

tables: List[context.Table] = []
if expr.locked_rels is not None:
for rvar in expr.locked_rels:
assert rvar.relation.name
table = _lookup_table(rvar.relation.name, ctx=ctx)
tables.append(table)
else:
resolved_lrels = None
tables.extend(ctx.scope.tables)

# validate that the locking clause can be used on these tables
for table in tables:
if table.schema_id and not table.is_direct_relation:
raise errors.QueryError(
f'locking clause not supported: `{table.name or table.alias}` '
'must not have child types or access policies',
pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED,
)

return pgast.LockingClause(
strength=expr.strength,
locked_rels=resolved_lrels,
locked_rels=[
pgast.RelRangeVar(relation=pgast.Relation(name=table.reference_as))
for table in tables
],
wait_policy=expr.wait_policy,
)

Expand Down
105 changes: 63 additions & 42 deletions edb/pgsql/resolver/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,22 +147,22 @@ def resolve_SelectStmt(

# order by can refer to columns in SELECT projection, so we need to add
# table.columns into scope
ctx.scope.tables.append(
context.Table(
columns=[
context.Column(
name=c.name,
kind=context.ColumnByName(reference_as=c.name),
)
for c, target in zip(table.columns, stmt.target_list)
if target.name
and (
not isinstance(target.val, pgast.ColumnRef)
or target.val.name[-1] != target.name
)
]
)
projected_table = context.Table(
columns=[
context.Column(
name=c.name,
kind=context.ColumnByName(reference_as=c.name),
)
for c, target in zip(table.columns, stmt.target_list)
if target.name
and (
not isinstance(target.val, pgast.ColumnRef)
or target.val.name[-1] != target.name
)
]
)
if len(projected_table.columns) > 0:
ctx.scope.tables.append(projected_table)

sort_clause = dispatch.resolve_opt_list(stmt.sort_clause, ctx=ctx)
limit_offset = dispatch.resolve_opt(stmt.limit_offset, ctx=ctx)
Expand Down Expand Up @@ -262,7 +262,9 @@ def resolve_relation(
for n, _type, _ver_since in preset_tables[0][relation.name]
]
cols.extend(_construct_system_columns())
table = context.Table(name=relation.name, columns=cols)
table = context.Table(
name=relation.name, columns=cols, is_direct_relation=True
)
rel = pgast.Relation(name=relation.name, schemaname=preset_tables[1])

return rel, table
Expand Down Expand Up @@ -330,7 +332,7 @@ def public_to_default(s: str) -> str:
if card.is_multi():
continue

columns.append(_construct_column(p, ctx, include_inherited))
columns.append(_construct_column(p, ctx))
else:
for c in ['source', 'target']:
columns.append(
Expand All @@ -354,31 +356,62 @@ def column_order_key(c: context.Column) -> Tuple[int, str]:

table.columns.extend(_construct_system_columns())

if ctx.options.apply_access_policies:
if ctx.options.apply_access_policies and _has_access_policies(obj, ctx):
if isinstance(obj, s_objtypes.ObjectType):
relation = _compile_read_of_obj_table(
rel = _compile_read_of_obj_table(
obj, include_inherited, table, ctx
)
else:
# link and multi-property tables cannot have access policies,
# so we allow access to base table directly
relation = _relation_of_table(obj, ctx)
# TODO: implement access policy filtering for link and
# multi-property tables
rel = _relation_of_table(obj, table, ctx)
else:
if include_inherited:
relation = _relation_of_inheritance_cte(obj, ctx)
if include_inherited and _has_sub_types(obj, ctx):
rel = _relation_of_inheritance_cte(obj, ctx)
else:
relation = _relation_of_table(obj, ctx)
rel = _relation_of_table(obj, table, ctx)

return relation, table
return rel, table


def _relation_of_table(
def _has_access_policies(
    obj: s_sources.Source | s_properties.Property, ctx: Context
) -> bool:
    """Return True if *obj* is an object type with any access policies.

    Pointer (link/multi-property) tables cannot carry access policies,
    so they always report False.
    """
    if isinstance(obj, s_pointers.Pointer):
        return False
    assert isinstance(obj, s_objtypes.ObjectType)

    policies = obj.get_access_policies(ctx.schema)
    return len(policies) > 0


def _has_sub_types(
    obj: s_sources.Source | s_properties.Property, ctx: Context
) -> bool:
    """Return True if *obj* is an object type that has any child types.

    Pointer tables have no type hierarchy of their own, so they always
    report False.
    """
    if isinstance(obj, s_pointers.Pointer):
        return False
    assert isinstance(obj, s_objtypes.ObjectType)

    return len(obj.children(ctx.schema)) > 0


def _relation_of_table(
    obj: s_sources.Source | s_properties.Property,
    table: context.Table,
    ctx: Context,
) -> pgast.Relation:
    """Resolve *obj* to its backing table directly (no view or CTE).

    Also marks *table* as a direct relation — the precondition for
    allowing locking clauses on it — and patches its ``__type__``
    column, since the physical table has no such column and its value
    is statically known for a single concrete type.
    """
    schemaname, dbname = pgcommon.get_backend_name(
        ctx.schema, obj, aspect='table', catenate=False
    )
    relation = pgast.Relation(name=dbname, schemaname=schemaname)

    table.is_direct_relation = True
    # When referencing actual tables, we need to statically provide __type__,
    # since this column does not exist in the database.
    for col in table.columns:
        if col.name == '__type__':
            col.kind = context.ColumnStaticVal(val=obj.id)
            break

    return relation


def _relation_of_inheritance_cte(
Expand Down Expand Up @@ -446,9 +479,7 @@ def _lookup_pointer_table(
raise NotImplementedError()


def _construct_column(
p: s_pointers.Pointer, ctx: Context, include_inherited: bool
) -> context.Column:
def _construct_column(p: s_pointers.Pointer, ctx: Context) -> context.Column:
short_name = p.get_shortname(ctx.schema)

col_name: str
Expand Down Expand Up @@ -476,17 +507,7 @@ def _construct_column(
kind = context.ColumnComputable(pointer=p)
elif short_name.name == '__type__':
col_name = '__type__'

if not include_inherited:
# When using FROM ONLY, we will be referencing actual tables
# and not inheritance views. Actual tables don't contain
# __type__ column, which means that we have to provide value
# in some other way. Fortunately, it is a constant value, so we
# can compute it statically.
source_id = p.get_source_type(ctx.schema).get_id(ctx.schema)
kind = context.ColumnStaticVal(val=source_id)
else:
kind = context.ColumnByName(reference_as='__type__')
kind = context.ColumnByName(reference_as='__type__')
else:
col_name = short_name.name + '_id'
_, dbname = pgcommon.get_backend_name(ctx.schema, p, catenate=False)
Expand Down
42 changes: 42 additions & 0 deletions tests/test_sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1283,3 +1283,45 @@ def test_sql_parse_copy_06(self):
"""
COPY country TO PROGRAM 'gzip > /usr1/proj/bray/sql/country_data.gz'
"""

def test_sql_parse_table(self):
    """
    TABLE hello_world
    % OK %
    SELECT * FROM hello_world
    """
    # The docstring is the test fixture: the SQL before "% OK %" is
    # parsed, and codegen output must equal the SQL after the marker
    # (TABLE x is sugar for SELECT * FROM x).

def test_sql_parse_select_locking_00(self):
    """
    SELECT id FROM a FOR UPDATE
    """
    # Fixture docstring: with no "% OK %" marker, parse -> codegen must
    # round-trip this FOR UPDATE clause byte-for-byte.

def test_sql_parse_select_locking_01(self):
    """
    SELECT id FROM a FOR NO KEY UPDATE
    """
    # Fixture docstring: round-trips the weaker NO KEY UPDATE strength.

def test_sql_parse_select_locking_02(self):
    """
    SELECT id FROM a FOR SHARE
    """
    # Fixture docstring: round-trips the FOR SHARE lock strength.

def test_sql_parse_select_locking_03(self):
    """
    SELECT id FROM a FOR KEY SHARE
    """
    # Fixture docstring: round-trips the weakest strength, KEY SHARE.

def test_sql_parse_select_locking_04(self):
    """
    SELECT id FROM a FOR UPDATE NOWAIT
    """
    # Fixture docstring: round-trips the NOWAIT wait policy.

def test_sql_parse_select_locking_05(self):
    """
    SELECT id FROM a FOR UPDATE SKIP LOCKED
    """
    # Fixture docstring: round-trips the SKIP LOCKED wait policy.

def test_sql_parse_select_locking_06(self):
    """
    SELECT id FROM a FOR UPDATE OF b
    """
    # Fixture docstring: round-trips a locked-rels list (FOR UPDATE OF).
Loading

0 comments on commit fab69b5

Please sign in to comment.