Skip to content

Commit

Permalink
DEFAULT VALUES and VALUES (DEFAULT)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen committed Jun 20, 2024
1 parent ea231e0 commit cf7eb2a
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 21 deletions.
89 changes: 69 additions & 20 deletions edb/pgsql/resolver/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,26 +161,9 @@ def resolve_InsertStmt(
((c.name, c.span) for c in stmt.cols) if stmt.cols else None,
)

# compile value that is to be inserted normally
val_rel: pgast.BaseRelation
if stmt.select_stmt:
val_rel = stmt.select_stmt
else:
# INSERT INTO x DEFAULT VALUES
val_rel = pgast.SelectStmt(values=[])

# edgeql compiler will provide default values
# (and complain about missing ones)
expected_columns = []
val_rel, val_table = dispatch.resolve_relation(val_rel, ctx=ctx)

if len(expected_columns) != len(val_table.columns):
col_names = ', '.join(c.name for c in expected_columns)
raise errors.QueryError(
f'INSERT expected {len(expected_columns)} columns, '
f'but got {len(val_table.columns)} (expecting {col_names})',
span=val_rel.span,
)
val_rel, val_table = compile_insert_value(
stmt.select_stmt, expected_columns, ctx
)

# if we are sure that we are inserting a single row,
# we can skip for loops and the iterator, so we generate better SQL
Expand Down Expand Up @@ -378,6 +361,71 @@ def resolve_InsertStmt(
return result_query, result_table


def compile_insert_value(
value_query: Optional[pgast.Query],
expected_columns: List[context.Column],
ctx: context.ResolverContextLevel,
) -> Tuple[pgast.BaseRelation, context.Table]:
# VALUES (DEFAULT)
if isinstance(value_query, pgast.SelectStmt) and value_query.values:
# find DEFAULT keywords in VALUES

def is_default(e: pgast.BaseExpr) -> bool:
return isinstance(e, pgast.Keyword) and e.name == 'DEFAULT'

default_columns = set()
for row in value_query.values:
assert isinstance(row, pgast.ImplicitRowExpr)

for to_remove, col in enumerate(row.args):
if is_default(col):
default_columns.add(to_remove)

# remove DEFAULT keywords and expected columns,
# so EdgeQL insert will not get those columns, which will use the
# property defaults.
for to_remove in sorted(default_columns, reverse=True):
del expected_columns[to_remove]

for r_index, row in enumerate(value_query.values):
assert isinstance(row, pgast.ImplicitRowExpr)

if not is_default(row.args[to_remove]):
raise errors.QueryError(
'DEFAULT keyword is supported only when '
'used for a column in all rows',
span=value_query.span,
pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED,
)
cols = list(row.args)
del cols[to_remove]
value_query.values[r_index] = row.replace(args=cols)

# INSERT INTO x DEFAULT VALUES
val_rel: pgast.BaseRelation
if value_query:
val_rel = value_query
else:
val_rel = pgast.SelectStmt(values=[])

# edgeql compiler will provide default values
# (and complain about missing ones)
expected_columns = []

# compile value that is to be inserted
val_rel, val_table = dispatch.resolve_relation(val_rel, ctx=ctx)

if len(expected_columns) != len(val_table.columns):
col_names = ', '.join(c.name for c in expected_columns)
raise errors.QueryError(
f'INSERT expected {len(expected_columns)} columns, '
f'but got {len(val_table.columns)} (expecting {col_names})',
span=val_rel.span,
)

return val_rel, val_table


def returning_rows(
returning_list: List[pgast.ResTarget],
subject_pointers: List[Tuple[str, str]],
Expand Down Expand Up @@ -463,6 +511,7 @@ def construct_insert_element_for_ptr(
if is_link:
# add .id for links, which will figure out that it has uuid type.
# This will make type cast to the object type into "find_by_id".
assert isinstance(ptr_ql, qlast.Path)
ptr_ql.steps.append(qlast.Ptr(name='id'))

ptr_target = ptr.get_target(ctx.schema)
Expand Down
29 changes: 28 additions & 1 deletion tests/test_sql_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# limitations under the License.
#


from edb.testbase import server as tb
from edb.tools import test

Expand Down Expand Up @@ -280,3 +279,31 @@ async def test_sql_dml_insert_13(self):
VALUES ('Briefing', FALSE)
'''
)

async def test_sql_dml_insert_14(self):
# default values

await self.scon.execute(
'''
INSERT INTO "Document" DEFAULT VALUES;
'''
)

await self.scon.execute(
'''
INSERT INTO "Document" (id, title) VALUES (DEFAULT, 'Report');
'''
)
res = await self.squery_values('SELECT title FROM "Document"')
self.assert_data_shape(res, tb.bag([[None], ['Report (new)']]))

with self.assertRaisesRegex(
asyncpg.FeatureNotSupportedError,
'DEFAULT keyword is supported only when '
'used for a column in all rows',
):
await self.scon.execute(
'''
INSERT INTO "Document" (title) VALUES ('Report'), (DEFAULT);
'''
)

0 comments on commit cf7eb2a

Please sign in to comment.