diff --git a/edb/edgeql/ast.py b/edb/edgeql/ast.py index b8bfd564ac0..cfff91f7fab 100644 --- a/edb/edgeql/ast.py +++ b/edb/edgeql/ast.py @@ -605,6 +605,8 @@ class DeleteQuery(PipelinedQuery): class ForQuery(Query): from_desugaring: bool = False + has_union: bool = True # whether UNION was used in the syntax + optional: bool = False iterator: Expr iterator_alias: str diff --git a/edb/edgeql/codegen.py b/edb/edgeql/codegen.py index 906bcdd2169..4879d36e8f0 100644 --- a/edb/edgeql/codegen.py +++ b/edb/edgeql/codegen.py @@ -119,15 +119,20 @@ def _kw_case(self, *kws: str) -> str: def _write_keywords(self, *kws: str) -> None: self.write(self._kw_case(*kws)) - def _needs_parentheses(self, node) -> bool: # type: ignore + def _needs_parentheses(self, node: Any) -> bool: # The "parent" attribute is set by calling `_fix_parent_links` # before traversing the AST. Since it's not an attribute that # can be inferred by static typing we ignore typing for this # function. + parent: Optional[qlast.Base] = node._parent return ( - node._parent is not None and ( - not isinstance(node._parent, qlast.Base) - or not isinstance(node._parent, qlast.DDL) + parent is not None + and not isinstance(parent, qlast.DDL) + # Non-union FOR bodies can't have parens + and not ( + isinstance(parent, qlast.ForQuery) + and not parent.has_union + and parent.result is node ) ) @@ -303,7 +308,7 @@ def visit_SelectQuery(self, node: qlast.SelectQuery) -> None: self.write(')') def visit_ForQuery(self, node: qlast.ForQuery) -> None: - # need to parenthesise when GROUP appears as an expression + # need to parenthesize when FOR appears as an expression parenthesise = self._needs_parentheses(node) if parenthesise: @@ -317,10 +322,13 @@ def visit_ForQuery(self, node: qlast.ForQuery) -> None: self.visit(node.iterator) # guarantee an newline here self.new_lines = 1 - self._write_keywords('UNION ') - self._block_ws(1) - self.visit(node.result) - self.indentation -= 1 + if node.has_union: + self._write_keywords('UNION ') + self._block_ws(1) + self.visit(node.result) + self.indentation -= 1 + else: + self.visit(node.result) if parenthesise: self.write(')') diff --git a/edb/edgeql/parser/grammar/expressions.py b/edb/edgeql/parser/grammar/expressions.py index 67766dd7b60..d20314c271e 100644 --- a/edb/edgeql/parser/grammar/expressions.py +++ b/edb/edgeql/parser/grammar/expressions.py @@ -168,14 +168,25 @@ def reduce_empty(self, *kids): class SimpleFor(Nonterm): - def reduce_For(self, *kids): - r"%reduce FOR OptionalOptional Identifier IN AtomicExpr \ - UNION Expr" + def reduce_ForIn(self, *kids): + r"%reduce FOR OptionalOptional Identifier IN AtomicExpr UNION Expr" + _, optional, iterator_alias, _, iterator, _, body = kids self.val = qlast.ForQuery( - optional=kids[1].val, - iterator_alias=kids[2].val, - iterator=kids[4].val, - result=kids[6].val, + optional=optional.val, + iterator_alias=iterator_alias.val, + iterator=iterator.val, + result=body.val, + ) + + def reduce_ForInStmt(self, *kids): + r"%reduce FOR OptionalOptional Identifier IN AtomicExpr ExprStmt" + _, optional, iterator_alias, _, iterator, body = kids + self.val = qlast.ForQuery( + has_union=False, + optional=optional.val, + iterator_alias=iterator_alias.val, + iterator=iterator.val, + result=body.val, ) diff --git a/tests/test_edgeql_for.py b/tests/test_edgeql_for.py index f376d96424c..9915b2ebf14 100644 --- a/tests/test_edgeql_for.py +++ b/tests/test_edgeql_for.py @@ -341,16 +341,12 @@ async def test_edgeql_for_in_computable_02b(self): select_deck := (( WITH cards := ( FOR letter IN {'I', 'B'} - UNION ( - FOR copy IN {'1', '2'} - UNION ( - SELECT User.deck { - name, - letter := letter ++ copy - } - FILTER User.deck.name[0] = letter - ) - ) + FOR copy IN {'1', '2'} + SELECT User.deck { + name, + letter := letter ++ copy + } + FILTER User.deck.name[0] = letter ) SELECT cards ORDER BY .name THEN .letter ),) diff --git a/tests/test_edgeql_insert.py b/tests/test_edgeql_insert.py index 45ed2e186e6..13c7145569f 100644 --- a/tests/test_edgeql_insert.py +++ b/tests/test_edgeql_insert.py @@ -1273,18 +1273,18 @@ async def test_edgeql_insert_policy_cast(self): async def test_edgeql_insert_for_01(self): await self.con.execute(r''' FOR x IN {3, 5, 7, 2} - UNION (INSERT InsertTest { + INSERT InsertTest { name := 'insert for 1', l2 := x, - }); + }; FOR Q IN (SELECT InsertTest{foo := 'foo' ++ InsertTest.l2} FILTER .name = 'insert for 1') - UNION (INSERT InsertTest { + INSERT InsertTest { name := 'insert for 1', l2 := 35 % Q.l2, l3 := Q.foo, - }); + }; ''') await self.assert_query_result( @@ -1384,12 +1384,10 @@ async def test_edgeql_insert_for_04(self): l2 := 999, subordinates := ( FOR x IN {('sub1', 'first'), ('sub2', 'second')} - UNION ( - INSERT Subordinate { - name := x.0, - @comment := x.1, - } - ) + INSERT Subordinate { + name := x.0, + @comment := x.1, + } ) }; ''') @@ -1417,9 +1415,9 @@ async def test_edgeql_insert_for_04(self): async def test_edgeql_insert_for_06(self): res = await self.con.query(r''' - FOR a in {"a", "b"} UNION ( - FOR b in {"c", "d"} UNION ( - INSERT Note {name := b})); + FOR a IN {"a", "b"} + FOR b IN {"c", "d"} + INSERT Note {name := b}; ''') self.assertEqual(len(res), 4) @@ -1433,9 +1431,9 @@ async def test_edgeql_insert_for_06(self): async def test_edgeql_insert_for_07(self): res = await self.con.query(r''' - FOR a in {"a", "b"} UNION ( - FOR b in {a++"c", a++"d"} UNION ( - INSERT Note {name := b})); + FOR a IN {"a", "b"} + FOR b IN {a++"c", a++"d"} + INSERT Note {name := b}; ''') self.assertEqual(len(res), 4) @@ -1449,10 +1447,10 @@ async def test_edgeql_insert_for_07(self): async def test_edgeql_insert_for_08(self): res = await self.con.query(r''' - FOR a in {"a", "b"} UNION ( - FOR b in {"a", "b"} UNION ( - FOR c in {a++b++"a", a++b++"b"} UNION ( - INSERT Note {name := c}))); + FOR a IN {"a", "b"} + FOR b IN {"a", "b"} + FOR c IN {a++b++"a", a++b++"b"} + INSERT Note {name := c}; ''') self.assertEqual(len(res), 8) diff --git a/tests/test_edgeql_syntax.py b/tests/test_edgeql_syntax.py index ab31d6b99d8..fc2e2ac9ce8 100644 --- a/tests/test_edgeql_syntax.py +++ b/tests/test_edgeql_syntax.py @@ -3519,6 +3519,19 @@ def test_edgeql_syntax_updatefor_01(self): UNION (UPDATE Foo FILTER (Foo.id = x.0) SET {bar := x.1}); """ + def test_edgeql_syntax_shorterfor_01(self): + """ + FOR x IN {1} + INSERT Foo { x := x }; + """ + + def test_edgeql_syntax_shorterfor_02(self): + """ + FOR x IN 1 + WITH y := x + INSERT Foo { y := y }; + """ + def test_edgeql_syntax_coalesce_01(self): """ SELECT (a ?? x);