diff --git a/.gitmodules b/.gitmodules index d3587af835b..e286e4030fa 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,5 +7,4 @@ url = https://github.com/MagicStack/py-pgproto.git [submodule "edb/pgsql/parser/libpg_query"] path = edb/pgsql/parser/libpg_query - url = https://github.com/pganalyze/libpg_query.git - branch = 13-latest + url = https://github.com/msullivan/libpg_query.git diff --git a/edb/pgsql/parser/ast_builder.py b/edb/pgsql/parser/ast_builder.py index 6de93130e2c..f9ee6ac06e0 100644 --- a/edb/pgsql/parser/ast_builder.py +++ b/edb/pgsql/parser/ast_builder.py @@ -136,8 +136,7 @@ def _build_any(node: Node, _: Context) -> Any: def _build_str(node: Node, _: Context) -> str: - node = _unwrap(node, "String") - node = _unwrap(node, "str") + node = _unwrap_string(node) return str(node) @@ -156,6 +155,42 @@ def _unwrap(node: Node, name: str) -> Node: return node +def _unwrap_boolean(n: Node) -> Node: + n = _unwrap(n, 'Boolean') + n = _unwrap(n, 'str') + n = _unwrap(n, 'boolval') + n = _unwrap(n, 'boolval') + if isinstance(n, dict) and len(n) == 0: + n = False + return n + + +def _unwrap_int(n: Node) -> Node: + n = _unwrap(n, 'Integer') + n = _unwrap(n, 'str') + n = _unwrap(n, 'ival') + n = _unwrap(n, 'ival') + if isinstance(n, dict) and len(n) == 0: + n = 0 + return n + + +def _unwrap_float(n: Node) -> Node: + n = _unwrap(n, 'Float') + n = _unwrap(n, 'str') + n = _unwrap(n, 'fval') + n = _unwrap(n, 'fval') + return n + + +def _unwrap_string(n: Node) -> Node: + n = _unwrap(n, 'String') + n = _unwrap(n, 'str') + n = _unwrap(n, 'sval') + n = _unwrap(n, 'sval') + return n + + def _probe(n: Node, keys: List[str | int]) -> bool: for key in keys: contained = key in n if isinstance(key, str) else key < len(n) @@ -739,20 +774,20 @@ def _build_type_cast(n: Node, c: Context) -> pgast.TypeCast: def _build_type_name(n: Node, c: Context) -> pgast.TypeName: n = _unwrap(n, "TypeName") - def unwrap_int(n: Node, _c: Context): - return _unwrap(_unwrap(n, 'Integer'), 'ival') - name: Tuple[str, ...] = tuple(_list(n, c, "names", _build_str)) # we don't escape char properly, so let's just resolve it during parsing if name == ("char",): name = ("pg_catalog", "char") + def unwrap_int_builder(n: Node, _c: Context) -> Node: + return _unwrap_int(n) + return pgast.TypeName( name=name, setof=_bool_or_false(n, "setof"), typmods=None, - array_bounds=_maybe_list(n, c, "arrayBounds", unwrap_int), + array_bounds=_maybe_list(n, c, "arrayBounds", unwrap_int_builder), span=_build_span(n, c), ) @@ -809,22 +844,25 @@ def _build_base_range_var(n: Node, c: Context) -> pgast.BaseRangeVar: def _build_const(n: Node, c: Context) -> pgast.BaseConstant: - val = n["val"] + n = _unwrap(n, "val") span = _build_span(n, c) - if "Integer" in val: - return pgast.NumericConstant( - val=str(val["Integer"]["ival"]), span=span - ) + if "Null" in n or "isnull" in n: + return pgast.NullConstant(span=span) - if "Float" in val: - return pgast.NumericConstant(val=val["Float"]["str"], span=span) + if "Boolean" in n or "boolval" in n: + return pgast.BooleanConstant(val=_unwrap_boolean(n), span=span) - if "Null" in val: - return pgast.NullConstant(span=span) + if "Integer" in n or "ival" in n: + return pgast.NumericConstant(val=str(_unwrap_int(n)), span=span) + + if "Float" in n or "fval" in n: + return pgast.NumericConstant(val=_unwrap_float(n), span=span) - if "String" in val: - return pgast.StringConstant(val=_build_str(val, c), span=span) + if "String" in n or "sval" in n: + return pgast.StringConstant( + val=_build_str(_unwrap_string(n), c), span=span + ) raise PSqlUnsupportedError(n) diff --git a/edb/pgsql/parser/libpg_query b/edb/pgsql/parser/libpg_query index 1097b2c33e5..c773fdd7100 160000 --- a/edb/pgsql/parser/libpg_query +++ b/edb/pgsql/parser/libpg_query @@ -1 +1 @@ -Subproject commit 1097b2c33e54a37c0d2c0f2d498c7d1cf967eae9 +Subproject commit c773fdd7100c175d0bfbb8be3a79d1f46b370f46 diff --git a/tests/test_sql_parse.py b/tests/test_sql_parse.py index b974f2402a1..bb1d4d870b5 100644 --- a/tests/test_sql_parse.py +++ b/tests/test_sql_parse.py @@ -449,6 +449,21 @@ def test_sql_parse_select_57(self): SELECT * FROM t_20210301_x """ + def test_sql_parse_select_58(self): + """ + SELECT (1.2 * 3.4) + """ + + def test_sql_parse_select_59(self): + """ + SELECT TRUE; SELECT FALSE + """ + + def test_sql_parse_select_60(self): + """ + SELECT -1; SELECT 0; SELECT 1 + """ + def test_sql_parse_insert_00(self): """ INSERT INTO my_table (id, name) VALUES (1, 'some')