diff --git a/edb/pgsql/ast.py b/edb/pgsql/ast.py index 2f0b948215d..b87f18b38a1 100644 --- a/edb/pgsql/ast.py +++ b/edb/pgsql/ast.py @@ -742,8 +742,17 @@ class NullConstant(BaseConstant): nullable: bool = True +class BitStringConstant(BaseConstant): + """A bit string constant.""" + + # x or b + kind: str + + val: str + + class ByteaConstant(BaseConstant): - """An bytea string.""" + """A bytea string.""" val: bytes diff --git a/edb/pgsql/codegen.py b/edb/pgsql/codegen.py index 7cd8876f5d4..ab0a7bf6064 100644 --- a/edb/pgsql/codegen.py +++ b/edb/pgsql/codegen.py @@ -765,6 +765,9 @@ def visit_BooleanConstant(self, node: pgast.BooleanConstant) -> None: def visit_StringConstant(self, node: pgast.StringConstant) -> None: self.write(common.quote_literal(node.val)) + def visit_BitStringConstant(self, node: pgast.BitStringConstant) -> None: + self.write(f"{node.kind}'{node.val}'") + def visit_ByteaConstant(self, node: pgast.ByteaConstant) -> None: self.write(common.quote_bytea_literal(node.val)) diff --git a/edb/pgsql/parser/ast_builder.py b/edb/pgsql/parser/ast_builder.py index 5e5157614a5..71d2b35e039 100644 --- a/edb/pgsql/parser/ast_builder.py +++ b/edb/pgsql/parser/ast_builder.py @@ -900,8 +900,7 @@ def _build_const(n: Node, c: Context) -> pgast.BaseConstant: n = _unwrap(n, 'str') n = _unwrap(n, 'bsval') n = _unwrap(n, 'bsval') - val = bytes.fromhex(n[1:]) - return pgast.ByteaConstant(val=val, span=span) + return pgast.BitStringConstant(kind=n[0], val=n[1:], span=span) raise PSqlUnsupportedError(n) diff --git a/edb/server/bootstrap.py b/edb/server/bootstrap.py index d2646cc6977..8d203e89164 100644 --- a/edb/server/bootstrap.py +++ b/edb/server/bootstrap.py @@ -503,7 +503,7 @@ async def _store_static_bin_cache_conn( INSERT INTO edgedbinstdata_VER.instdata (key, bin) VALUES( {pg_common.quote_literal(key)}, - {pg_common.quote_bytea_literal(data)}::bytea + {pg_common.quote_bytea_literal(data)} ) """) @@ -1060,7 +1060,7 @@ def prepare_patch( if k in bins: if k not in rawbin: v = pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) - val = f'{pg_common.quote_bytea_literal(v)}::bytea' + val = f'{pg_common.quote_bytea_literal(v)}' sys_updates += (trampoline.fixup_query(f''' INSERT INTO edgedbinstdata_VER.instdata (key, bin) VALUES({key}, {val}) diff --git a/tests/test_sql_query.py b/tests/test_sql_query.py index 2e4c994350e..f42e200bf67 100644 --- a/tests/test_sql_query.py +++ b/tests/test_sql_query.py @@ -772,9 +772,19 @@ async def test_sql_query_40(self): self.assertEqual(res, [[id]]) async def test_sql_query_41(self): - # bytea literal + from asyncpg.types import BitString + + # bit string literal res = await self.squery_values("SELECT x'00abcdef00';") - self.assertEqual(res, [[b'\x00\xab\xcd\xef\x00']]) + self.assertEqual(res, [[BitString.frombytes(b'\x00\xab\xcd\xef\x00')]]) + + res = await self.squery_values("SELECT x'01001ab';") + self.assertEqual( + res, [[BitString.frombytes(b'\x01\x00\x1a\xb0', bitlength=28)]] + ) + + res = await self.squery_values("SELECT b'101';") + self.assertEqual(res, [[BitString.frombytes(b'\xa0', bitlength=3)]]) async def test_sql_query_42(self): # params out of order