Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bit string literals over SQL adapter #8112

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion edb/pgsql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,8 +742,17 @@ class NullConstant(BaseConstant):
nullable: bool = True


class BitStringConstant(BaseConstant):
"""A bit string constant."""

# x or b
kind: str

val: str


class ByteaConstant(BaseConstant):
"""An bytea string."""
"""A bytea string."""

val: bytes

Expand Down
3 changes: 3 additions & 0 deletions edb/pgsql/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,9 @@ def visit_BooleanConstant(self, node: pgast.BooleanConstant) -> None:
def visit_StringConstant(self, node: pgast.StringConstant) -> None:
self.write(common.quote_literal(node.val))

def visit_BitStringConstant(self, node: pgast.BitStringConstant) -> None:
self.write(f"{node.kind}'{node.val}'")

def visit_ByteaConstant(self, node: pgast.ByteaConstant) -> None:
self.write(common.quote_bytea_literal(node.val))

Expand Down
3 changes: 1 addition & 2 deletions edb/pgsql/parser/ast_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,8 +900,7 @@ def _build_const(n: Node, c: Context) -> pgast.BaseConstant:
n = _unwrap(n, 'str')
n = _unwrap(n, 'bsval')
n = _unwrap(n, 'bsval')
val = bytes.fromhex(n[1:])
return pgast.ByteaConstant(val=val, span=span)
return pgast.BitStringConstant(kind=n[0], val=n[1:], span=span)
raise PSqlUnsupportedError(n)


Expand Down
4 changes: 2 additions & 2 deletions edb/server/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ async def _store_static_bin_cache_conn(
INSERT INTO edgedbinstdata_VER.instdata (key, bin)
VALUES(
{pg_common.quote_literal(key)},
{pg_common.quote_bytea_literal(data)}::bytea
{pg_common.quote_bytea_literal(data)}
)
""")

Expand Down Expand Up @@ -1060,7 +1060,7 @@ def prepare_patch(
if k in bins:
if k not in rawbin:
v = pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL)
val = f'{pg_common.quote_bytea_literal(v)}::bytea'
val = f'{pg_common.quote_bytea_literal(v)}'
sys_updates += (trampoline.fixup_query(f'''
INSERT INTO edgedbinstdata_VER.instdata (key, bin)
VALUES({key}, {val})
Expand Down
14 changes: 12 additions & 2 deletions tests/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,9 +772,19 @@ async def test_sql_query_40(self):
self.assertEqual(res, [[id]])

async def test_sql_query_41(self):
# bytea literal
from asyncpg.types import BitString

# bit string literal
res = await self.squery_values("SELECT x'00abcdef00';")
self.assertEqual(res, [[b'\x00\xab\xcd\xef\x00']])
self.assertEqual(res, [[BitString.frombytes(b'\x00\xab\xcd\xef\x00')]])

res = await self.squery_values("SELECT x'01001ab';")
self.assertEqual(
res, [[BitString.frombytes(b'\x01\x00\x1a\xb0', bitlength=28)]]
)

res = await self.squery_values("SELECT b'101';")
self.assertEqual(res, [[BitString.frombytes(b'\xa0', bitlength=3)]])

async def test_sql_query_42(self):
# params out of order
Expand Down
Loading