From a763e4804def2c5d10b4dc358f342c0159ffe312 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Sat, 17 Aug 2024 09:46:22 -0600 Subject: [PATCH] Rust postgres client --- Cargo.lock | 89 ++ Cargo.toml | 2 + connect.py | 39 + edb/server/bootstrap.py | 31 +- edb/server/cluster.py | 26 +- edb/server/inplace_upgrade.py | 20 +- edb/server/main.py | 62 +- edb/server/multitenant.py | 17 +- edb/server/pgcluster.py | 188 ++-- edb/server/pgcon/__init__.py | 7 +- edb/server/pgcon/__init__.pyi | 10 +- edb/server/pgcon/connect.py | 213 ++++ edb/server/pgcon/pgcon.pxd | 12 +- edb/server/pgcon/pgcon.pyx | 557 +--------- edb/server/pgcon/rust_transport.py | 473 +++++++++ edb/server/pgcon/scram.pxd | 46 - edb/server/pgcon/scram.pyx | 370 ------- edb/server/pgconnparams.py | 287 +++-- edb/server/pgrust/Cargo.toml | 17 +- edb/server/pgrust/__init__.pyi | 2 + edb/server/pgrust/examples/connect.rs | 176 ++++ edb/server/pgrust/examples/dsn.rs | 8 +- edb/server/pgrust/src/auth/md5.rs | 50 + edb/server/pgrust/src/auth/mod.rs | 11 + edb/server/pgrust/src/auth/scram.rs | 815 +++++++++++++++ edb/server/pgrust/src/auth/stringprep.rs | 209 ++++ .../pgrust/src/auth/stringprep_table.rs | 801 ++++++++++++++ .../pgrust/src/auth/stringprep_table_prep.py | 43 + edb/server/pgrust/src/conn_string.rs | 978 ------------------ edb/server/pgrust/src/connection/conn.rs | 406 ++++++++ edb/server/pgrust/src/connection/dsn.rs | 817 +++++++++++++++ edb/server/pgrust/src/connection/mod.rs | 134 +++ edb/server/pgrust/src/connection/openssl.rs | 121 +++ edb/server/pgrust/src/connection/params.rs | 412 ++++++++ edb/server/pgrust/src/connection/raw_conn.rs | 240 +++++ .../pgrust/src/connection/raw_params.rs | 553 ++++++++++ .../pgrust/src/connection/state_machine.rs | 417 ++++++++ edb/server/pgrust/src/connection/stream.rs | 195 ++++ edb/server/pgrust/src/connection/tokio.rs | 136 +++ edb/server/pgrust/src/lib.rs | 8 +- edb/server/pgrust/src/protocol/arrays.rs | 346 +++++++ edb/server/pgrust/src/protocol/buffer.rs 
| 224 ++++ edb/server/pgrust/src/protocol/datatypes.rs | 508 +++++++++ edb/server/pgrust/src/protocol/definition.rs | 740 +++++++++++++ edb/server/pgrust/src/protocol/gen.rs | 776 ++++++++++++++ .../pgrust/src/protocol/message_group.rs | 123 +++ edb/server/pgrust/src/protocol/mod.rs | 392 +++++++ edb/server/pgrust/src/protocol/writer.rs | 97 ++ edb/server/pgrust/src/python.rs | 425 +++++++- edb/server/pgrust/tests/edgedb_test_cases.rs | 285 +++++ .../pgrust/tests/hardcore_host_tests_cases.rs | 168 +++ edb/server/pgrust/tests/libpq_test_cases.rs | 329 ++++++ edb/server/pgrust/tests/real_postgres.rs | 375 +++++++ .../pgrust/tests/test_util/dsn_libpq.rs | 256 +++++ edb/server/pgrust/tests/test_util/mod.rs | 204 ++++ edb/server/protocol/binary.pyx | 12 +- edb/server/render_dsn.py | 57 - edb/server/tenant.py | 36 +- edb/testbase/connection.py | 6 +- edb/testbase/server.py | 17 +- edb/tools/wipe.py | 14 +- pyproject.toml | 1 + tests/test_backend_connect.py | 677 +----------- tests/test_server_ops.py | 68 +- 64 files changed, 12055 insertions(+), 3079 deletions(-) create mode 100644 connect.py create mode 100644 edb/server/pgcon/connect.py create mode 100644 edb/server/pgcon/rust_transport.py delete mode 100644 edb/server/pgcon/scram.pxd delete mode 100644 edb/server/pgcon/scram.pyx create mode 100644 edb/server/pgrust/__init__.pyi create mode 100644 edb/server/pgrust/examples/connect.rs create mode 100644 edb/server/pgrust/src/auth/md5.rs create mode 100644 edb/server/pgrust/src/auth/mod.rs create mode 100644 edb/server/pgrust/src/auth/scram.rs create mode 100644 edb/server/pgrust/src/auth/stringprep.rs create mode 100644 edb/server/pgrust/src/auth/stringprep_table.rs create mode 100644 edb/server/pgrust/src/auth/stringprep_table_prep.py delete mode 100644 edb/server/pgrust/src/conn_string.rs create mode 100644 edb/server/pgrust/src/connection/conn.rs create mode 100644 edb/server/pgrust/src/connection/dsn.rs create mode 100644 edb/server/pgrust/src/connection/mod.rs 
create mode 100644 edb/server/pgrust/src/connection/openssl.rs create mode 100644 edb/server/pgrust/src/connection/params.rs create mode 100644 edb/server/pgrust/src/connection/raw_conn.rs create mode 100644 edb/server/pgrust/src/connection/raw_params.rs create mode 100644 edb/server/pgrust/src/connection/state_machine.rs create mode 100644 edb/server/pgrust/src/connection/stream.rs create mode 100644 edb/server/pgrust/src/connection/tokio.rs create mode 100644 edb/server/pgrust/src/protocol/arrays.rs create mode 100644 edb/server/pgrust/src/protocol/buffer.rs create mode 100644 edb/server/pgrust/src/protocol/datatypes.rs create mode 100644 edb/server/pgrust/src/protocol/definition.rs create mode 100644 edb/server/pgrust/src/protocol/gen.rs create mode 100644 edb/server/pgrust/src/protocol/message_group.rs create mode 100644 edb/server/pgrust/src/protocol/mod.rs create mode 100644 edb/server/pgrust/src/protocol/writer.rs create mode 100644 edb/server/pgrust/tests/edgedb_test_cases.rs create mode 100644 edb/server/pgrust/tests/hardcore_host_tests_cases.rs create mode 100644 edb/server/pgrust/tests/libpq_test_cases.rs create mode 100644 edb/server/pgrust/tests/real_postgres.rs create mode 100644 edb/server/pgrust/tests/test_util/dsn_libpq.rs create mode 100644 edb/server/pgrust/tests/test_util/mod.rs delete mode 100644 edb/server/render_dsn.py diff --git a/Cargo.lock b/Cargo.lock index 9aea97518bb9..39b823bde50a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -598,6 +598,16 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys", +] + [[package]] name = "factorial" version = "0.2.1" @@ -607,6 +617,12 @@ dependencies = [ 
"num-traits", ] +[[package]] +name = "fastrand" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" + [[package]] name = "foreign-types" version = "0.3.2" @@ -812,6 +828,12 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "hex-literal" version = "0.4.1" @@ -916,6 +938,12 @@ version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + [[package]] name = "log" version = "0.4.22" @@ -950,6 +978,12 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.7.4" @@ -1194,17 +1228,21 @@ dependencies = [ "consume_on_drop", "derive_more", "futures", + "hex", "hex-literal", "hexdump", "hmac", "itertools 0.13.0", + "libc", "lru", + "md5", "openssl", "paste", "percent-encoding", "pretty_assertions", "pyo3", "rand", + "roaring", "rstest", "scopeguard", "serde", @@ -1212,12 +1250,15 @@ dependencies = [ "serde_derive", "sha2", "smart-default", + "socket2", "statrs", "stringprep", "strum", + "tempfile", "test-log", "thiserror", "tokio", + "tokio-openssl", "tracing", "tracing-subscriber", "unicode-normalization", @@ -1516,6 +1557,16 @@ version = 
"1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" +[[package]] +name = "roaring" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f4b84ba6e838ceb47b41de5194a60244fac43d9fe03b71dbe8c5a201081d6d1" +dependencies = [ + "bytemuck", + "byteorder", +] + [[package]] name = "rstest" version = "0.22.0" @@ -1561,6 +1612,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "0.38.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a85d50532239da68e9addb745ba38ff4612a242c1c7ceea689c4bc7c2f43c36f" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "rustversion" version = "1.0.17" @@ -1825,6 +1889,19 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "tempfile" +version = "3.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fcd239983515c23a32fb82099f97d0b11b8c72f654ed659363a95c3dad7a53" +dependencies = [ + "cfg-if", + "fastrand", + "once_cell", + "rustix", + "windows-sys", +] + [[package]] name = "test-log" version = "0.2.16" @@ -1919,6 +1996,18 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-openssl" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ffab79df67727f6acf57f1ff743091873c24c579b1e2ce4d8f53e47ded4d63d" +dependencies = [ + "futures-util", + "openssl", + "openssl-sys", + "tokio", +] + [[package]] name = "toml_datetime" version = "0.6.8" diff --git a/Cargo.toml b/Cargo.toml index 44deebe15c3e..f8cfaa543b49 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,8 @@ resolver = "2" [workspace.dependencies] pyo3 = { version = "0.22.2", features = ["extension-module", "serde"] } tokio = 
{ version = "1", features = ["rt", "rt-multi-thread", "macros", "time", "sync", "net", "io-util"] } +tracing = "0.1.40" +tracing-subscriber = "0.3.18" [profile.release] debug = true diff --git a/connect.py b/connect.py new file mode 100644 index 000000000000..9e9b786f3c45 --- /dev/null +++ b/connect.py @@ -0,0 +1,39 @@ +from edb.server.pgcon.rust_transport import create_postgres_connection +import asyncio +import time + +class MyProtocol(asyncio.Protocol): + def __init__(self): + self.closed = asyncio.Future() + + def connection_made(self, transport): + print(f"Connection made: {transport.connection}") + self.transport = transport + + def data_received(self, data): + print(f"Received: {data}") + self.transport.close() + + def connection_lost(self, exc): + print(f"Connection lost: {exc}") + self.closed.set_result(None) + +async def main(): + now = time.perf_counter_ns() + transport, protocol = await create_postgres_connection( + "postgres://user:password@localhost/postgres", + lambda: MyProtocol(), + state_change_callback=lambda state: print(f"Connection: {state.name}")) + print(f"Connection time: {(time.perf_counter_ns() - now) // 1000}µs") + print(f"Connected: {transport}") + print(f"Peer: {transport.get_extra_info('peername')}") + print(f"Cipher: {transport.get_extra_info('cipher')}") + + # Send a simple query message + query = b'SELECT version();\0' + message = b'Q' + (len(query) + 4).to_bytes(4, 'big') + query + transport.write(message) + + await protocol.closed + +asyncio.run(main()) diff --git a/edb/server/bootstrap.py b/edb/server/bootstrap.py index d1e6e5dba2a6..df567e5ee308 100644 --- a/edb/server/bootstrap.py +++ b/edb/server/bootstrap.py @@ -118,6 +118,7 @@ class PGConnectionProxy: def __init__( self, cluster: pgcluster.BaseCluster, + source_description: str, dbname: Optional[str] = None, log_listener: Optional[Callable[[str, str], None]] = None, ): @@ -125,15 +126,21 @@ def __init__( self._cluster = cluster self._dbname = dbname self._log_listener = 
log_listener or _pg_log_listener + self._source_description = source_description async def connect(self) -> None: if self._conn is not None: self._conn.terminate() if self._dbname: - self._conn = await self._cluster.connect(database=self._dbname) + self._conn = await self._cluster.connect( + source_description=self._source_description, + database=self._dbname + ) else: - self._conn = await self._cluster.connect() + self._conn = await self._cluster.connect( + source_description=self._source_description + ) if self._log_listener is not None: self._conn.add_log_listener(self._log_listener) @@ -2310,7 +2317,11 @@ async def _check_catalog_compatibility( ) ) - conn = PGConnectionProxy(ctx.cluster, sys_db.decode("utf-8")) + conn = PGConnectionProxy( + ctx.cluster, + source_description="_check_catalog_compatibility", + dbname=sys_db.decode("utf-8") + ) try: instancedata = await _get_instance_data(conn) @@ -2473,7 +2484,11 @@ async def _bootstrap( else: new_template_db_id = uuidgen.uuid1mc() tpl_db = cluster.get_db_name(edbdef.EDGEDB_TEMPLATE_DB) - conn = PGConnectionProxy(cluster, tpl_db) + conn = PGConnectionProxy( + cluster, + source_description="_bootstrap", + dbname=tpl_db + ) tpl_ctx = dataclasses.replace(ctx, conn=conn) else: @@ -2596,7 +2611,8 @@ async def _bootstrap( sys_conn = PGConnectionProxy( cluster, - cluster.get_db_name(edbdef.EDGEDB_SYSTEM_DB), + source_description="_bootstrap", + dbname=cluster.get_db_name(edbdef.EDGEDB_SYSTEM_DB), ) try: @@ -2667,7 +2683,10 @@ async def ensure_bootstrapped( Returns True if bootstrap happened and False if the instance was already bootstrapped, along with the bootstrap compiler state. 
""" - pgconn = PGConnectionProxy(cluster) + pgconn = PGConnectionProxy( + cluster, + source_description="ensure_bootstrapped" + ) ctx = BootstrapContext(cluster=cluster, conn=pgconn, args=args) try: diff --git a/edb/server/cluster.py b/edb/server/cluster.py index 1116eec07d7d..908829489df4 100644 --- a/edb/server/cluster.py +++ b/edb/server/cluster.py @@ -38,6 +38,7 @@ from edb.server import args as edgedb_args from edb.server import defines as edgedb_defines +from edb.server import pgconnparams from . import pgcluster @@ -118,7 +119,7 @@ def __init__( self._runstate_dir = runstate_dir self._edgedb_cmd.extend(['--runstate-dir', str(runstate_dir)]) self._pg_cluster: Optional[pgcluster.BaseCluster] = None - self._pg_connect_args: Dict[str, Any] = {} + self._pg_connect_args: pgconnparams.CreateParamsKwargs = {} self._daemon_process: Optional[subprocess.Popen[str]] = None self._port = port self._effective_port = None @@ -146,7 +147,7 @@ async def get_status(self) -> str: conn = None try: conn = await pg_cluster.connect( - timeout=5, + source_description=f"{self.__class__}.get_status", **self._pg_connect_args, ) @@ -314,37 +315,34 @@ async def test() -> None: started = time.monotonic() await test() left -= (time.monotonic() - started) - - if self._admin_query("SELECT ();", f"{max(1, int(left))}s"): + if res := self._admin_query("SELECT ();", f"{max(1, int(left))}s"): raise ClusterError( f'could not connect to edgedb-server ' - f'within {timeout} seconds') from None + f'within {timeout} seconds (exit code = {res})') from None def _admin_query( self, query: str, wait_until_available: str = "0s", ) -> int: - return subprocess.call( - [ + args = [ "edgedb", - "--host", + "query", + "--unix-path", str(os.path.abspath(self._runstate_dir)), "--port", str(self._effective_port), "--admin", "--user", edgedb_defines.EDGEDB_SUPERUSER, - "--database", + "--branch", edgedb_defines.EDGEDB_SUPERUSER_DB, "--wait-until-available", wait_until_available, - "-c", query, - ], - 
stdout=subprocess.DEVNULL, - stderr=subprocess.STDOUT, - ) + ] + res = subprocess.call(args=args) + return res async def set_test_config(self) -> None: self._admin_query(f''' diff --git a/edb/server/inplace_upgrade.py b/edb/server/inplace_upgrade.py index 08a8e55d5acf..8f1eb435e668 100644 --- a/edb/server/inplace_upgrade.py +++ b/edb/server/inplace_upgrade.py @@ -434,7 +434,10 @@ async def _get_databases( cluster = ctx.cluster tpl_db = cluster.get_db_name(edbdef.EDGEDB_TEMPLATE_DB) - conn = await cluster.connect(database=tpl_db) + conn = await cluster.connect( + source_description="inplace upgrade", + database=tpl_db + ) # FIXME: Use the sys query instead? try: @@ -500,7 +503,10 @@ async def _upgrade_all( continue conn = bootstrap.PGConnectionProxy( - cluster, cluster.get_db_name(database)) + cluster, + source_description="inplace upgrade: upgrade all", + dbname=cluster.get_db_name(database) + ) try: subctx = dataclasses.replace(ctx, conn=conn) @@ -529,7 +535,10 @@ async def go( inject_failure_on: Optional[str]=None, ) -> None: for database in databases: - conn = await cluster.connect(database=cluster.get_db_name(database)) + conn = await cluster.connect( + source_description="inplace upgrade: finish", + database=cluster.get_db_name(database) + ) try: subctx = dataclasses.replace(ctx, conn=conn) @@ -567,7 +576,10 @@ async def inplace_upgrade( args: edbargs.ServerConfig, ) -> None: """Perform some or all of the inplace upgrade operations""" - pgconn = bootstrap.PGConnectionProxy(cluster) + pgconn = bootstrap.PGConnectionProxy( + cluster, + source_description="inplace_upgrade" + ) ctx = bootstrap.BootstrapContext(cluster=cluster, conn=pgconn, args=args) try: diff --git a/edb/server/main.py b/edb/server/main.py index 9b9309e703bf..6d7aaa306624 100644 --- a/edb/server/main.py +++ b/edb/server/main.py @@ -36,7 +36,6 @@ import asyncio import contextlib -import dataclasses import logging import os import os.path @@ -65,7 +64,6 @@ from . import daemon from . 
import defines from . import logsetup -from . import pgconnparams from . import pgcluster from . import service_manager @@ -362,14 +360,12 @@ async def _get_local_pgcluster( tenant_id=tenant_id, log_level=args.log_level, ) - cluster.set_connection_params( - pgconnparams.ConnectionParameters( - user='postgres', - database='template1', - server_settings={ - "application_name": f'edgedb_instance_{args.instance_name}', - } - ), + cluster.update_connection_params( + user='postgres', + database='template1', + server_settings={ + "application_name": f'edgedb_instance_{args.instance_name}', + } ) return cluster, args @@ -400,15 +396,9 @@ async def _get_remote_pgcluster( abort(f'--max-backend-connections is too large for this backend; ' f'detected maximum available NUM: {max_conns}') - conn_params = cluster.get_connection_params() - conn_params = dataclasses.replace( - conn_params, - server_settings=dict( - conn_params.server_settings, - application_name=f'edgedb_instance_{args.instance_name}', - ), - ) - cluster.set_connection_params(conn_params) + cluster.update_connection_params(server_settings={ + 'application_name': f'edgedb_instance_{args.instance_name}' + }) return cluster, args @@ -638,25 +628,25 @@ async def run_server( is srvargs.JOSEKeyMode.Generate ) ): - conn_params = cluster.get_connection_params() instance_name = args.instance_name - conn_params = dataclasses.replace( - conn_params, - server_settings={ - **conn_params.server_settings, - **backend_settings, - 'application_name': f'edgedb_instance_{instance_name}', - 'edgedb.instance_name': instance_name, - 'edgedb.server_version': buildmeta.get_version_json(), - }, - ) - if args.data_dir: - conn_params.database = pgcluster.get_database_backend_name( - defines.EDGEDB_TEMPLATE_DB, - tenant_id=tenant_id, + database = pgcluster.get_database_backend_name( + defines.EDGEDB_TEMPLATE_DB, + tenant_id=tenant_id, + ) if args.data_dir else None + server_settings = { + 'application_name': 
f'edgedb_instance_{instance_name}', + 'edgedb.instance_name': instance_name, + 'edgedb.server_version': buildmeta.get_version_json(), + } + if database: + cluster.update_connection_params( + database=database, + server_settings=server_settings + ) + else: + cluster.update_connection_params( + server_settings=server_settings ) - - cluster.set_connection_params(conn_params) with _internal_state_dir(runstate_dir, args) as ( int_runstate_dir, diff --git a/edb/server/multitenant.py b/edb/server/multitenant.py index e215a6b8a09b..f3b5f910b05e 100644 --- a/edb/server/multitenant.py +++ b/edb/server/multitenant.py @@ -21,7 +21,6 @@ import asyncio import collections -import dataclasses import json import logging import pathlib @@ -225,17 +224,11 @@ async def _create_tenant(self, conf: TenantConfig) -> edbtenant.Tenant: else: max_conns = conf["max-backend-connections"] - conn_params = cluster.get_connection_params() - conn_params = dataclasses.replace( - conn_params, - server_settings={ - **conn_params.server_settings, - "application_name": f'edgedb_instance_{conf["instance-name"]}', - "edgedb.instance_name": conf["instance-name"], - "edgedb.server_version": buildmeta.get_version_json(), - }, - ) - cluster.set_connection_params(conn_params) + cluster.update_connection_params(server_settings={ + "application_name": f'edgedb_instance_{conf["instance-name"]}', + "edgedb.instance_name": conf["instance-name"], + "edgedb.server_version": buildmeta.get_version_json(), + }) tenant = edbtenant.Tenant( cluster, diff --git a/edb/server/pgcluster.py b/edb/server/pgcluster.py index a604156a52a5..c2965ab68d3a 100644 --- a/edb/server/pgcluster.py +++ b/edb/server/pgcluster.py @@ -27,6 +27,7 @@ Mapping, Sequence, Coroutine, + Unpack, Dict, List, cast, @@ -56,16 +57,14 @@ from edb.server import args as srvargs from edb.server import defines +from edb.server import pgconnparams from edb.server.ha import base as ha_base from edb.pgsql import common as pgcommon from edb.pgsql import params as 
pgparams -from . import pgconnparams - if TYPE_CHECKING: from edb.server import pgcon - logger = logging.getLogger('edb.pgcluster') pg_dump_logger = logging.getLogger('pg_dump') pg_restore_logger = logging.getLogger('pg_restore') @@ -77,6 +76,15 @@ get_database_backend_name = pgcommon.get_database_backend_name get_role_backend_name = pgcommon.get_role_backend_name +EDGEDB_SERVER_SETTINGS = { + 'client_encoding': 'utf-8', + 'search_path': 'edgedb', + 'timezone': 'UTC', + 'intervalstyle': 'iso_8601', + 'jit': 'off', + 'default_transaction_isolation': 'serializable', +} + class ClusterError(Exception): pass @@ -94,9 +102,8 @@ def __init__( instance_params: Optional[pgparams.BackendInstanceParams] = None, ) -> None: self._connection_addr: Optional[Tuple[str, int]] = None - self._connection_params: Optional[ - pgconnparams.ConnectionParameters - ] = None + self._connection_params: pgconnparams.ConnectionParams = \ + pgconnparams.ConnectionParams(server_settings=EDGEDB_SERVER_SETTINGS) self._pg_config_data: Dict[str, str] = {} self._pg_bin_dir: Optional[pathlib.Path] = None if instance_params is None: @@ -129,7 +136,9 @@ def get_role_name(self, role_name: str) -> str: assert ( role_name == defines.EDGEDB_SUPERUSER ), f"role_name={role_name} is not allowed" - return self.get_connection_params().user + rv = self.get_connection_params().user + assert rv is not None + return rv return get_database_backend_name( role_name, @@ -151,19 +160,29 @@ async def stop(self, wait: int = 60) -> None: def destroy(self) -> None: raise NotImplementedError - async def connect(self, **kwargs: Any) -> pgcon.PGConnection: + async def connect(self, + *, + source_description: str, + apply_init_script: bool = False, + **kwargs: Unpack[pgconnparams.CreateParamsKwargs] + ) -> pgcon.PGConnection: + """Connect to this cluster, with optional overriding parameters. 
If + overriding parameters are specified, they are applied to a copy of the + connection parameters before the connection takes place.""" from edb.server import pgcon - conn_info = self.get_connection_spec() - conn_info.update(kwargs) - dbname = conn_info.get("database") or conn_info.get("user") - assert isinstance(dbname, str) - return await pgcon.connect( - conn_info, - dbname=dbname, + connection = self.get_connection_params().clone() + addr = self._get_connection_addr() + assert addr is not None + connection.update(hosts=[addr]) + connection.update(**kwargs) + conn = await pgcon.connect( + connection, + source_description=source_description, backend_params=self.get_runtime_params(), - apply_init_script=False, + apply_init_script=apply_init_script, ) + return conn async def start_watching( self, failover_cb: Optional[Callable[[], None]] = None @@ -191,56 +210,26 @@ def overwrite_capabilities( capabilities=caps ) - def get_connection_addr(self) -> Optional[Tuple[str, int]]: - return self._get_connection_addr() - - def set_connection_params( + def update_connection_params( self, - params: pgconnparams.ConnectionParameters, + **kwargs: Unpack[pgconnparams.CreateParamsKwargs], ) -> None: - self._connection_params = params + self._connection_params.update(**kwargs) + + def get_pgaddr(self) -> str: + assert self._connection_params is not None + addr = self._get_connection_addr() + assert addr is not None + params = self._connection_params.clone() + params.update(hosts=[addr]) + return params.to_dsn() def get_connection_params( self, - ) -> pgconnparams.ConnectionParameters: + ) -> pgconnparams.ConnectionParams: assert self._connection_params is not None return self._connection_params - def get_connection_spec(self) -> Dict[str, Any]: - conn_dict: Dict[str, Any] = {} - addr = self.get_connection_addr() - assert addr is not None - conn_dict['host'] = addr[0] - conn_dict['port'] = addr[1] - params = self.get_connection_params() - for k in ( - 'user', - 'password', - 
'database', - 'ssl', - 'sslmode', - 'server_settings', - 'connect_timeout', - ): - v = getattr(params, k) - if v is not None: - conn_dict[k] = v - - cluster_settings = conn_dict.get('server_settings', {}) - - edgedb_settings = { - 'client_encoding': 'utf-8', - 'search_path': 'edgedb', - 'timezone': 'UTC', - 'intervalstyle': 'iso_8601', - 'jit': 'off', - 'default_transaction_isolation': 'serializable', - } - - conn_dict['server_settings'] = {**cluster_settings, **edgedb_settings} - - return conn_dict - def _get_connection_addr(self) -> Optional[Tuple[str, int]]: return self._connection_addr @@ -254,18 +243,21 @@ def _dump_restore_conn_args( self, dbname: str, ) -> tuple[list[str], dict[str, str]]: - conn_spec = self.get_connection_spec() + params = self.get_connection_params().clone() + addr = self._get_connection_addr() + assert addr is not None + params.update(database=dbname, hosts=[addr]) args = [ - f'--dbname={dbname}', - f'--host={conn_spec["host"]}', - f'--port={conn_spec["port"]}', - f'--username={conn_spec["user"]}', + f'--dbname={params.database}', + f'--host={params.host}', + f'--port={params.port}', + f'--username={params.user}', ] env = os.environ.copy() - if conn_spec.get("password"): - env['PGPASSWORD'] = conn_spec["password"] + if params.password: + env['PGPASSWORD'] = params.password return args, env @@ -313,7 +305,6 @@ async def dump_database( for flag, vals in configs: for val in vals: args.append(f'--{flag}={val}') - stdout_lines, _, _ = await _run_logged_subprocess( args, logger=pg_dump_logger, @@ -776,6 +767,10 @@ async def _test_connection(self, timeout: int = 60) -> str: self._connection_addr = None connected = False + params = pgconnparams.ConnectionParams( + user="postgres", + database="postgres") + for n in range(timeout + 9): # pg usually comes up pretty quickly, but not so quickly # that we don't hit the wait case. 
Make our first several @@ -802,16 +797,11 @@ async def _test_connection(self, timeout: int = 60) -> str: continue try: + params.update(hosts=[conn_addr]) con = await asyncio.wait_for( pgcon.connect( - dbname="postgres", - connargs={ - "user": "postgres", - "database": "postgres", - "host": conn_addr[0], - "port": conn_addr[1], - "server_settings": {}, - }, + params, + source_description=f"{self.__class__}._test_connection", backend_params=self.get_runtime_params(), apply_init_script=False, ), @@ -821,7 +811,9 @@ async def _test_connection(self, timeout: int = 60) -> str: OSError, asyncio.TimeoutError, pgcon.BackendConnectionError, - ): + ) as e: + if n % 10 == 0 and n > 0: + logger.error("_test_connection failed", e) await asyncio.sleep(sleep_time) continue except pgcon.BackendError: @@ -843,15 +835,18 @@ async def _test_connection(self, timeout: int = 60) -> str: class RemoteCluster(BaseCluster): def __init__( self, - addr: Tuple[str, int], - params: pgconnparams.ConnectionParameters, *, + connection_addr: tuple[str, int], + connection_params: pgconnparams.ConnectionParams, instance_params: Optional[pgparams.BackendInstanceParams] = None, ha_backend: Optional[ha_base.HABackend] = None, ): super().__init__(instance_params=instance_params) - self._connection_addr = addr - self._connection_params = params + self._connection_params = connection_params + self._connection_params.update( + server_settings=EDGEDB_SERVER_SETTINGS + ) + self._connection_addr = connection_addr self._ha_backend = ha_backend def _get_connection_addr(self) -> Optional[Tuple[str, int]]: @@ -987,6 +982,7 @@ async def get_remote_pg_cluster( tenant_id: Optional[str] = None, specified_capabilities: Optional[srvargs.BackendCapabilitySets] = None, ) -> RemoteCluster: + from edb.server import pgcon parsed = urllib.parse.urlparse(dsn) ha_backend = None @@ -1019,14 +1015,10 @@ async def get_remote_pg_cluster( if query: dsn += f"?{urllib.parse.urlencode(query)}" - addrs, params = 
pgconnparams.parse_dsn(dsn) - if len(addrs) > 1: - raise ValueError('multiple hosts in Postgres DSN are not supported') if tenant_id is None: t_id = buildmeta.get_default_tenant_id() else: t_id = tenant_id - rcluster = RemoteCluster(addrs[0], params) async def _get_cluster_type( conn: pgcon.PGConnection, @@ -1194,7 +1186,16 @@ async def _get_reserved_connections( rv += int(value) return rv - conn = await rcluster.connect() + probe_connection = pgconnparams.ConnectionParams(dsn=dsn) + conn = await pgcon.connect( + probe_connection, + source_description="remote cluster probe", + backend_params=pgparams.get_default_runtime_params(), + apply_init_script=False + ) + params = conn.connection + addr = conn.addr + try: data = json.loads(await conn.sql_fetch_val( b""" @@ -1208,12 +1209,17 @@ async def _get_reserved_connections( ) )""", )) - user = data["user"] - dbname = data["dbname"] + params.update( + user=data["user"], + database=data["dbname"] + ) cluster_type, superuser_name = await _get_cluster_type(conn) max_connections = data["connlimit"] - if max_connections == -1: - max_connections = await _get_pg_settings(conn, 'max_connections') + pg_max_connections = await _get_pg_settings(conn, 'max_connections') + if max_connections == -1 or not isinstance(max_connections, int): + max_connections = pg_max_connections + else: + max_connections = min(max_connections, pg_max_connections) capabilities = await _detect_capabilities(conn) if ( @@ -1299,11 +1305,9 @@ async def _get_reserved_connections( finally: conn.terminate() - params.user = user - params.database = dbname return cluster_type( - addrs[0], - params, + connection_addr=addr, + connection_params=params, instance_params=instance_params, ha_backend=ha_backend, ) diff --git a/edb/server/pgcon/__init__.py b/edb/server/pgcon/__init__.py index 79fbc6621629..9006fd122c3f 100644 --- a/edb/server/pgcon/__init__.py +++ b/edb/server/pgcon/__init__.py @@ -27,8 +27,13 @@ ) from .pgcon import ( - connect, PGConnection, 
SETUP_TEMP_TABLE_SCRIPT, SETUP_CONFIG_CACHE_SCRIPT, + PGConnection, +) +from .connect import ( + connect, set_init_con_script_data, + SETUP_TEMP_TABLE_SCRIPT, + SETUP_CONFIG_CACHE_SCRIPT ) __all__ = ( diff --git a/edb/server/pgcon/__init__.pyi b/edb/server/pgcon/__init__.pyi index ca4d965e3fa6..ebd81b6a3378 100644 --- a/edb/server/pgcon/__init__.pyi +++ b/edb/server/pgcon/__init__.pyi @@ -21,11 +21,12 @@ from __future__ import annotations from typing import ( Any, Callable, - Dict, Optional, ) from edb.pgsql import params as pg_params +from edb.server import pgconnparams +from .rust_transport import ConnectionParams class BackendError(Exception): @@ -45,9 +46,10 @@ class BackendCatalogNameError(BackendError): ... async def connect( - connargs: Dict[str, Any], - dbname: str, + dsn_or_connection: str | ConnectionParams, + *, backend_params: pg_params.BackendRuntimeParams, + source_description: str, apply_init_script: bool = True, ) -> PGConnection: ... @@ -59,6 +61,8 @@ class PGConnection: idle: bool backend_pid: int + connection: pgconnparams.ConnectionParams + addr: tuple[str, int] async def sql_execute(self, sql: bytes | tuple[bytes, ...]) -> None: ... diff --git a/edb/server/pgcon/connect.py b/edb/server/pgcon/connect.py new file mode 100644 index 000000000000..01a5296bc060 --- /dev/null +++ b/edb/server/pgcon/connect.py @@ -0,0 +1,213 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import logging +import textwrap + +from edb.server.pgconnparams import SSLMode +from edb.pgsql.common import quote_ident as pg_qi +from edb.pgsql.common import quote_literal as pg_ql +from edb.pgsql import params as pg_params +from . import errors as pgerror +from .rust_transport import ConnectionParams, create_postgres_connection +from .pgcon import PGConnection + +logger = logging.getLogger('edb.server') + +INIT_CON_SCRIPT = None +INIT_CON_SCRIPT_DATA = '' + +# The '_edgecon_state table' is used to store information about +# the current session. The `type` column is one character, with one +# of the following values: +# +# * 'C': a session-level config setting +# +# * 'B': a session-level config setting that's implemented by setting +# a corresponding Postgres config setting. +# * 'A': an instance-level config setting from command-line arguments +# * 'E': an instance-level config setting from environment variable +SETUP_TEMP_TABLE_SCRIPT = ''' + CREATE TEMPORARY TABLE _edgecon_state ( + name text NOT NULL, + value jsonb NOT NULL, + type text NOT NULL CHECK( + type = 'C' OR type = 'B' OR type = 'A' OR type = 'E'), + UNIQUE(name, type) + ); +''' +SETUP_CONFIG_CACHE_SCRIPT = ''' + CREATE TEMPORARY TABLE _config_cache ( + source edgedb._sys_config_source_t, + value edgedb._sys_config_val_t NOT NULL + ); +''' + + +def _build_init_con_script(*, check_pg_is_in_recovery: bool) -> bytes: + if check_pg_is_in_recovery: + pg_is_in_recovery = (''' + SELECT CASE WHEN pg_is_in_recovery() THEN + edgedb.raise( + NULL::bigint, + 'read_only_sql_transaction', + msg => 'cannot use a hot standby' + ) + END; + ''').strip() + else: + pg_is_in_recovery = '' + + return textwrap.dedent(f''' + {pg_is_in_recovery} + + {SETUP_TEMP_TABLE_SCRIPT} + {SETUP_CONFIG_CACHE_SCRIPT} + + {INIT_CON_SCRIPT_DATA} + + PREPARE _clear_state AS + WITH x1 AS ( + DELETE FROM _config_cache 
+ ) + DELETE FROM _edgecon_state WHERE type = 'C' OR type = 'B'; + + PREPARE _apply_state(jsonb) AS + INSERT INTO + _edgecon_state(name, value, type) + SELECT + (CASE + WHEN e->'type' = '"B"'::jsonb + THEN edgedb._apply_session_config(e->>'name', e->'value') + ELSE e->>'name' + END) AS name, + e->'value' AS value, + e->>'type' AS type + FROM + jsonb_array_elements($1::jsonb) AS e; + + PREPARE _reset_session_config AS + SELECT edgedb._reset_session_config(); + + PREPARE _apply_sql_state(jsonb) AS + SELECT + e.key AS name, + pg_catalog.set_config(e.key, e.value, false) AS value + FROM + jsonb_each_text($1::jsonb) AS e; + ''').strip().encode('utf-8') + + +async def connect( + dsn_or_connection: str | ConnectionParams, + *, + backend_params: pg_params.BackendRuntimeParams, + source_description: str, + apply_init_script: bool = True, +) -> PGConnection: + global INIT_CON_SCRIPT + + pgcon = None + + if isinstance(dsn_or_connection, str): + connection = ConnectionParams(dsn=dsn_or_connection) + else: + connection = dsn_or_connection + try: + pgrawcon, pgcon = await create_postgres_connection( + connection, + lambda: PGConnection(dbname=connection.database), + source_description=source_description + ) + except pgerror.BackendConnectionError: + # If we don't mandate SSL, try again with no SSL + if (connection.sslmode == SSLMode.allow or + connection.sslmode == SSLMode.prefer): + connection = connection.clone() + connection.update(sslmode=SSLMode.disable) + pgrawcon, pgcon = await create_postgres_connection( + connection, + lambda: PGConnection(dbname=connection.database), + source_description=source_description + ) + else: + raise + + connection = pgrawcon.connection + pgcon.connection = pgrawcon.connection + pgcon.parameter_status = pgrawcon.state.parameters + cancellation_key = pgrawcon.state.cancellation_key + if cancellation_key: + pgcon.backend_pid = cancellation_key[0] + pgcon.backend_secret = cancellation_key[1] + pgcon.is_ssl = pgrawcon.state.ssl + pgcon.addr = 
pgrawcon.addr + + if ( + backend_params.has_create_role + and backend_params.session_authorization_role + ): + sup_role = backend_params.session_authorization_role + if connection.user != sup_role: + # We used to use SET SESSION AUTHORIZATION here, there're some + # security differences over SET ROLE, but as we don't allow + # accessing Postgres directly through EdgeDB, SET ROLE is mostly + # fine here. (Also hosted backends like Postgres on DigitalOcean + # support only SET ROLE) + await pgcon.sql_execute(f'SET ROLE {pg_qi(sup_role)}'.encode()) + + if 'in_hot_standby' in pgcon.parameter_status: + # in_hot_standby is always present in Postgres 14 and above + if pgcon.parameter_status['in_hot_standby'] == 'on': + # Abort if we're connecting to a hot standby + pgcon.terminate() + raise pgerror.BackendError(fields=dict( + M="cannot use a hot standby", + C=pgerror.ERROR_READ_ONLY_SQL_TRANSACTION, + )) + + if apply_init_script: + if INIT_CON_SCRIPT is None: + INIT_CON_SCRIPT = _build_init_con_script( + # On lower versions of Postgres we use pg_is_in_recovery() to + # check if it is a hot standby, and error out if it is. 
+ check_pg_is_in_recovery=( + 'in_hot_standby' not in pgcon.parameter_status + ), + ) + try: + await pgcon.sql_execute(INIT_CON_SCRIPT) + except Exception: + logger.exception( + f"Failed to run init script for {pgcon.connection.to_dsn()}" + ) + await pgcon.close() + raise + + return pgcon + + +def set_init_con_script_data(cfg): + global INIT_CON_SCRIPT, INIT_CON_SCRIPT_DATA + INIT_CON_SCRIPT = None + INIT_CON_SCRIPT_DATA = (f''' + INSERT INTO _edgecon_state + SELECT * FROM jsonb_to_recordset({pg_ql(json.dumps(cfg))}::jsonb) + AS cfg(name text, value jsonb, type text); + ''').strip() diff --git a/edb/server/pgcon/pgcon.pxd b/edb/server/pgcon/pgcon.pxd index 434121c257e2..181c0f5422e1 100644 --- a/edb/server/pgcon/pgcon.pxd +++ b/edb/server/pgcon/pgcon.pxd @@ -34,8 +34,6 @@ from edb.server.pgproto.debug cimport PG_DEBUG from edb.server.cache cimport stmt_cache -include "scram.pxd" - cdef enum PGTransactionStatus: PQTRANS_IDLE = 0 # connection idle @@ -111,9 +109,9 @@ cdef class PGConnection: int32_t waiting_for_sync PGTransactionStatus xact_status - readonly int32_t backend_pid - readonly int32_t backend_secret - readonly object parameter_status + public int32_t backend_pid + public int32_t backend_secret + public object parameter_status readonly object aborted_with_error @@ -124,7 +122,8 @@ cdef class PGConnection: bint debug - object pgaddr + public object connection + public object addr object server object tenant bint is_system_db @@ -157,7 +156,6 @@ cdef class PGConnection: cdef write_sync(self, WriteBuffer outbuf) cdef make_clean_stmt_message(self, bytes stmt_name) - cdef make_auth_password_md5_message(self, bytes salt) cdef send_query_unit_group( self, object query_unit_group, bint sync, object bind_datas, bytes state, diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index c3f706c65804..71a311077081 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -31,8 +31,6 @@ import hashlib import json import logging import 
os.path -import socket -import ssl as ssl_mod import sys import struct import textwrap @@ -80,7 +78,6 @@ from edb.server.cache cimport stmt_cache from edb.server.dbview cimport dbview from edb.server.protocol cimport args_ser from edb.server.protocol cimport pg_ext -from edb.server import pgconnparams from edb.server import metrics from edb.server.protocol cimport frontend @@ -89,13 +86,8 @@ from edb.common import debug from . import errors as pgerror -include "scram.pyx" - DEF DATA_BUFFER_SIZE = 100_000 DEF PREP_STMTS_CACHE = 100 -DEF TCP_KEEPIDLE = 24 -DEF TCP_KEEPINTVL = 2 -DEF TCP_KEEPCNT = 3 DEF COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0" @@ -108,348 +100,10 @@ cdef dict POSTGRES_SHUTDOWN_ERR_CODES = { '57P02': 'crash_shutdown', } -cdef bytes INIT_CON_SCRIPT = None -cdef str INIT_CON_SCRIPT_DATA = '' cdef object EMPTY_SQL_STATE = b"{}" cdef object logger = logging.getLogger('edb.server') -# The '_edgecon_state table' is used to store information about -# the current session. The `type` column is one character, with one -# of the following values: -# -# * 'C': a session-level config setting -# -# * 'B': a session-level config setting that's implemented by setting -# a corresponding Postgres config setting. 
-# * 'A': an instance-level config setting from command-line arguments -# * 'E': an instance-level config setting from environment variable -SETUP_TEMP_TABLE_SCRIPT = ''' - CREATE TEMPORARY TABLE _edgecon_state ( - name text NOT NULL, - value jsonb NOT NULL, - type text NOT NULL CHECK( - type = 'C' OR type = 'B' OR type = 'A' OR type = 'E'), - UNIQUE(name, type) - ); -''' -SETUP_CONFIG_CACHE_SCRIPT = ''' - CREATE TEMPORARY TABLE _config_cache ( - source edgedb._sys_config_source_t, - value edgedb._sys_config_val_t NOT NULL - ); -''' - -def _build_init_con_script(*, check_pg_is_in_recovery: bool) -> bytes: - if check_pg_is_in_recovery: - pg_is_in_recovery = (''' - SELECT CASE WHEN pg_is_in_recovery() THEN - edgedb.raise( - NULL::bigint, - 'read_only_sql_transaction', - msg => 'cannot use a hot standby' - ) - END; - ''').strip() - else: - pg_is_in_recovery = '' - - return textwrap.dedent(f''' - {pg_is_in_recovery} - - {SETUP_TEMP_TABLE_SCRIPT} - {SETUP_CONFIG_CACHE_SCRIPT} - - {INIT_CON_SCRIPT_DATA} - - PREPARE _clear_state AS - WITH x1 AS ( - DELETE FROM _config_cache - ) - DELETE FROM _edgecon_state WHERE type = 'C' OR type = 'B'; - - PREPARE _apply_state(jsonb) AS - INSERT INTO - _edgecon_state(name, value, type) - SELECT - (CASE - WHEN e->'type' = '"B"'::jsonb - THEN edgedb._apply_session_config(e->>'name', e->'value') - ELSE e->>'name' - END) AS name, - e->'value' AS value, - e->>'type' AS type - FROM - jsonb_array_elements($1::jsonb) AS e; - - PREPARE _reset_session_config AS - SELECT edgedb._reset_session_config(); - - PREPARE _apply_sql_state(jsonb) AS - SELECT - e.key AS name, - pg_catalog.set_config(e.key, e.value, false) AS value - FROM - jsonb_each_text($1::jsonb) AS e; - ''').strip().encode('utf-8') - - -def _set_tcp_keepalive(transport): - # TCP keepalive was initially added here for special cases where idle - # connections are dropped silently on GitHub Action running test suite - # against AWS RDS. 
We are keeping the TCP keepalive for generic - # Postgres connections as the kernel overhead is considered low, and - # in certain cases it does save us some reconnection time. - # - # In case of high-availability Postgres, TCP keepalive is necessary to - # disconnect from a failing master node, if no other failover information - # is available. - sock = transport.get_extra_info('socket') - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - - # TCP_KEEPIDLE: the time (in seconds) the connection needs to remain idle - # before TCP starts sending keepalive probes. This is socket.TCP_KEEPIDLE - # on Linux, and socket.TCP_KEEPALIVE on macOS from Python 3.10. - if hasattr(socket, 'TCP_KEEPIDLE'): - sock.setsockopt(socket.IPPROTO_TCP, - socket.TCP_KEEPIDLE, TCP_KEEPIDLE) - if hasattr(socket, 'TCP_KEEPALIVE'): - sock.setsockopt(socket.IPPROTO_TCP, - socket.TCP_KEEPALIVE, TCP_KEEPIDLE) - - # TCP_KEEPINTVL: The time (in seconds) between individual keepalive probes. - if hasattr(socket, 'TCP_KEEPINTVL'): - sock.setsockopt(socket.IPPROTO_TCP, - socket.TCP_KEEPINTVL, TCP_KEEPINTVL) - - # TCP_KEEPCNT: The maximum number of keepalive probes TCP should send - # before dropping the connection. - if hasattr(socket, 'TCP_KEEPCNT'): - sock.setsockopt(socket.IPPROTO_TCP, - socket.TCP_KEEPCNT, TCP_KEEPCNT) - - -async def _create_ssl_connection(protocol_factory, host, port, *, - loop, ssl_context, ssl_is_advisory, - connect_timeout): - async with asyncio.timeout(connect_timeout): - tr, pr = await loop.create_connection( - lambda: TLSUpgradeProto(loop, host, port, - ssl_context, ssl_is_advisory), - host, port) - _set_tcp_keepalive(tr) - - tr.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message. 
- - try: - do_ssl_upgrade = await pr.on_data - except (Exception, asyncio.CancelledError): - tr.close() - raise - - if do_ssl_upgrade: - try: - new_tr = await loop.start_tls( - tr, pr, ssl_context, server_hostname=host) - except (Exception, asyncio.CancelledError): - tr.close() - raise - else: - new_tr = tr - - pg_proto = protocol_factory() - pg_proto.is_ssl = do_ssl_upgrade - pg_proto.connection_made(new_tr) - new_tr.set_protocol(pg_proto) - - return new_tr, pg_proto - - -class _RetryConnectSignal(Exception): - pass - - -async def _connect(connargs, dbname, ssl): - - loop = asyncio.get_running_loop() - - host = connargs.get("host") - port = connargs.get("port") - sslmode = connargs.get('sslmode', pgconnparams.SSLMode.prefer) - timeout = connargs.get('connect_timeout') - - try: - if host.startswith('/'): - addr = os.path.join(host, f'.s.PGSQL.{port}') - async with asyncio.timeout(timeout): - _, pgcon = await loop.create_unix_connection( - lambda: PGConnection(dbname, loop, connargs), addr) - - else: - if ssl: - _, pgcon = await _create_ssl_connection( - lambda: PGConnection(dbname, loop, connargs), - host, - port, - loop=loop, - ssl_context=ssl, - ssl_is_advisory=( - sslmode == pgconnparams.SSLMode.prefer - ), - connect_timeout=timeout, - ) - else: - async with asyncio.timeout(timeout): - trans, pgcon = await loop.create_connection( - lambda: PGConnection(dbname, loop, connargs), - host=host, port=port) - _set_tcp_keepalive(trans) - except TimeoutError as ex: - raise pgerror.new( - pgerror.ERROR_CONNECTION_FAILURE, - "timed out connecting to backend", - ) from ex - - try: - await pgcon.connect() - except pgerror.BackendError as e: - pgcon.terminate() - if not e.code_is(pgerror.ERROR_INVALID_AUTHORIZATION_SPECIFICATION): - raise - - if ( - sslmode == pgconnparams.SSLMode.allow and not pgcon.is_ssl or - sslmode == pgconnparams.SSLMode.prefer and pgcon.is_ssl - ): - # Trigger retry when: - # 1. First attempt with sslmode=allow, ssl=None failed - # 2. 
First attempt with sslmode=prefer, ssl=ctx failed while the - # server claimed to support SSL (returning "S" for SSLRequest) - # (likely because pg_hba.conf rejected the connection) - raise _RetryConnectSignal() - - else: - # but will NOT retry if: - # 1. First attempt with sslmode=prefer failed but the server - # doesn't support SSL (returning 'N' for SSLRequest), because - # we already tried to connect without SSL thru ssl_is_advisory - # 2. Second attempt with sslmode=prefer, ssl=None failed - # 3. Second attempt with sslmode=allow, ssl=ctx failed - # 4. Any other sslmode - raise - - return pgcon - - -async def connect( - connargs: Dict[str, Any], - dbname: str, - backend_params: pg_params.BackendRuntimeParams, - apply_init_script: bool = True, -): - global INIT_CON_SCRIPT - - # This is different than parsing DSN and use the default sslmode=prefer, - # because connargs can be set manually thru set_connection_params(), and - # the caller should be responsible for aligning sslmode with ssl. - sslmode = connargs.get('sslmode', pgconnparams.SSLMode.disable) - ssl = connargs.get('ssl') - if sslmode == pgconnparams.SSLMode.allow: - try: - pgcon = await _connect(connargs, dbname, ssl=None) - except _RetryConnectSignal: - pgcon = await _connect(connargs, dbname, ssl=ssl) - elif sslmode == pgconnparams.SSLMode.prefer: - try: - pgcon = await _connect(connargs, dbname, ssl=ssl) - except _RetryConnectSignal: - pgcon = await _connect(connargs, dbname, ssl=None) - else: - pgcon = await _connect(connargs, dbname, ssl=ssl) - - if ( - backend_params.has_create_role - and backend_params.session_authorization_role - ): - sup_role = backend_params.session_authorization_role - if connargs['user'] != sup_role: - # We used to use SET SESSION AUTHORIZATION here, there're some - # security differences over SET ROLE, but as we don't allow - # accessing Postgres directly through EdgeDB, SET ROLE is mostly - # fine here. 
(Also hosted backends like Postgres on DigitalOcean - # support only SET ROLE) - await pgcon.sql_execute(f'SET ROLE {pg_qi(sup_role)}'.encode()) - - if 'in_hot_standby' in pgcon.parameter_status: - # in_hot_standby is always present in Postgres 14 and above - if pgcon.parameter_status['in_hot_standby'] == 'on': - # Abort if we're connecting to a hot standby - pgcon.terminate() - raise pgerror.BackendError(fields=dict( - M="cannot use a hot standby", - C=pgerror.ERROR_READ_ONLY_SQL_TRANSACTION, - )) - - if apply_init_script: - if INIT_CON_SCRIPT is None: - INIT_CON_SCRIPT = _build_init_con_script( - # On lower versions of Postgres we use pg_is_in_recovery() to - # check if it is a hot standby, and error out if it is. - check_pg_is_in_recovery=( - 'in_hot_standby' not in pgcon.parameter_status - ), - ) - try: - await pgcon.sql_execute(INIT_CON_SCRIPT) - except Exception: - await pgcon.close() - raise - - return pgcon - - -def set_init_con_script_data(cfg): - global INIT_CON_SCRIPT, INIT_CON_SCRIPT_DATA - INIT_CON_SCRIPT = None - INIT_CON_SCRIPT_DATA = (f''' - INSERT INTO _edgecon_state - SELECT * FROM jsonb_to_recordset({pg_ql(json.dumps(cfg))}::jsonb) - AS cfg(name text, value jsonb, type text); - ''').strip() - - -class TLSUpgradeProto(asyncio.Protocol): - def __init__(self, loop, host, port, ssl_context, ssl_is_advisory): - self.on_data = loop.create_future() - self.host = host - self.port = port - self.ssl_context = ssl_context - self.ssl_is_advisory = ssl_is_advisory - - def data_received(self, data): - if data == b'S': - self.on_data.set_result(True) - elif (self.ssl_is_advisory and - self.ssl_context.verify_mode == ssl_mod.CERT_NONE and - data == b'N'): - # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE, - # since the only way to get ssl_is_advisory is from - # sslmode=prefer. But be extra sure to disallow insecure - # connections when the ssl context asks for real security. 
- self.on_data.set_result(False) - else: - self.on_data.set_exception( - ConnectionError( - 'PostgreSQL server at "{host}:{port}" ' - 'rejected SSL upgrade'.format( - host=self.host, port=self.port))) - - def connection_lost(self, exc): - if not self.on_data.done(): - if exc is None: - exc = ConnectionError('unexpected connection_lost() call') - self.on_data.set_exception(exc) - @cython.final cdef class EdegDBCodecContext(pgproto.CodecContext): @@ -561,18 +215,19 @@ cdef class PGMessage: @cython.final cdef class PGConnection: - def __init__(self, dbname, loop, addr): + def __init__(self, dbname): self.buffer = ReadBuffer() - self.loop = loop + self.loop = asyncio.get_running_loop() self.dbname = dbname + self.connection = None self.transport = None self.msg_waiter = None self.prep_stmts = stmt_cache.StatementsCache(maxsize=PREP_STMTS_CACHE) - self.connected_fut = loop.create_future() + self.connected_fut = self.loop.create_future() self.connected = False self.waiting_for_sync = 0 @@ -587,7 +242,6 @@ cdef class PGConnection: self.log_listeners = [] - self.pgaddr = addr self.server = None self.tenant = None self.is_system_db = False @@ -626,9 +280,6 @@ cdef class PGConnection: file=sys.stderr, ) - def get_pgaddr(self): - return self.pgaddr - def in_tx(self): return ( self.xact_status == PQTRANS_INTRANS or @@ -2813,97 +2464,6 @@ cdef class PGConnection: finally: await self.after_command() - async def connect(self): - cdef: - WriteBuffer outbuf - WriteBuffer buf - char mtype - int32_t status - - if self.connected_fut is not None: - await self.connected_fut - if self.connected: - raise RuntimeError('already connected') - if self.transport is None: - raise RuntimeError('no transport object in connect()') - - buf = WriteBuffer() - - # protocol version - buf.write_int16(3) - buf.write_int16(0) - - for k, v in self.pgaddr['server_settings'].items(): - buf.write_bytestring(k.encode('utf-8')) - buf.write_bytestring(v.encode('utf-8')) - - buf.write_bytestring(b'user') - 
buf.write_bytestring(self.pgaddr['user'].encode('utf-8')) - - buf.write_bytestring(b'database') - buf.write_bytestring(self.dbname.encode('utf-8')) - - buf.write_bytestring(b'') - - # Send the buffer - outbuf = WriteBuffer() - outbuf.write_int32(buf.len() + 4) - outbuf.write_buffer(buf) - self.write(outbuf) - - # Need this to handle first ReadyForQuery - self.waiting_for_sync += 1 - - while True: - if not self.buffer.take_message(): - await self.wait_for_message() - mtype = self.buffer.get_message_type() - - try: - if mtype == b'R': - # Authentication... - status = self.buffer.read_int32() - if status == PGAUTH_SUCCESSFUL: - pass - elif status == PGAUTH_REQUIRED_PASSWORDMD5: - # Note: MD5 salt is passed as a four-byte sequence - md5_salt = self.buffer.read_bytes(4) - self.write( - self.make_auth_password_md5_message(md5_salt)) - - elif status == PGAUTH_REQUIRED_SASL: - await self._auth_sasl() - - else: - raise RuntimeError(f'unsupported auth method: {status}') - - elif mtype == b'K': - # BackendKeyData - self.backend_pid = self.buffer.read_int32() - self.backend_secret = self.buffer.read_int32() - - elif mtype == b'E': - # ErrorResponse - er_cls, er_fields = self.parse_error_message() - raise er_cls(fields=er_fields) - - elif mtype == b'Z': - # ReadyForQuery - self.parse_sync_message() - self.connected = True - break - - elif mtype == b'S': - # ParameterStatus - name, value = self.parse_parameter_status_message() - self.parameter_status[name] = value - - else: - self.fallthrough() - - finally: - self.buffer.finish_message() - def is_healthy(self): return ( self.connected and @@ -3160,114 +2720,6 @@ cdef class PGConnection: buf.write_bytestring(stmt_name) return buf.end_message() - cdef make_auth_password_md5_message(self, bytes salt): - cdef WriteBuffer msg - - msg = WriteBuffer.new_message(b'p') - - user = self.pgaddr.get('user') or '' - password = self.pgaddr.get('password') or '' - - # 'md5' + md5(md5(password + username) + salt)) - userpass = (password + 
user).encode('ascii') - hash = hashlib.md5(hashlib.md5(userpass).hexdigest().\ - encode('ascii') + salt).hexdigest().encode('ascii') - - msg.write_bytestring(b'md5' + hash) - return msg.end_message() - - async def _auth_sasl(self): - methods = [] - auth_method = self.buffer.read_null_str() - while auth_method: - methods.append(auth_method) - auth_method = self.buffer.read_null_str() - self.buffer.finish_message() - - if not methods: - raise RuntimeError( - 'the backend requested SASL authentication but did not ' - 'offer any methods') - - for method in methods: - if method in SCRAMAuthentication.AUTHENTICATION_METHODS: - break - else: - raise RuntimeError( - f'the backend offered the following SASL authentication ' - f'methods: {b", ".join(methods).decode()}, neither are ' - f'supported.' - ) - - user = self.pgaddr.get('user') or '' - password = self.pgaddr.get('password') or '' - scram = SCRAMAuthentication(method) - - msg = WriteBuffer.new_message(b'p') - msg.write_bytes(scram.create_client_first_message(user)) - msg.end_message() - self.write(msg) - - while True: - if not self.buffer.take_message(): - await self.wait_for_message() - mtype = self.buffer.get_message_type() - - if mtype == b'E': - # ErrorResponse - er_cls, er_fields = self.parse_error_message() - raise er_cls(fields=er_fields) - - elif mtype == b'R': - # Authentication... 
- break - - else: - self.fallthrough() - - status = self.buffer.read_int32() - if status != PGAUTH_SASL_CONTINUE: - raise RuntimeError( - f'expected SASLContinue from the server, received {status}') - - server_response = self.buffer.consume_message() - scram.parse_server_first_message(server_response) - msg = WriteBuffer.new_message(b'p') - client_final_message = scram.create_client_final_message(password) - msg.write_bytes(client_final_message) - msg.end_message() - - self.write(msg) - - while True: - if not self.buffer.take_message(): - await self.wait_for_message() - mtype = self.buffer.get_message_type() - - if mtype == b'E': - # ErrorResponse - er_cls, er_fields = self.parse_error_message() - raise er_cls(fields=er_fields) - - elif mtype == b'R': - # Authentication... - break - - else: - self.fallthrough() - - status = self.buffer.read_int32() - if status != PGAUTH_SASL_FINAL: - raise RuntimeError( - f'expected SASLFinal from the server, received {status}') - - server_response = self.buffer.consume_message() - if not scram.verify_server_final_message(server_response): - raise pgerror.BackendError(fields=dict( - M="server SCRAM proof does not match", - C=pgerror.ERROR_INVALID_PASSWORD, - )) - async def wait_for_message(self): if self.buffer.take_message(): return @@ -3280,6 +2732,7 @@ cdef class PGConnection: if self.transport is not None: raise RuntimeError('connection_made: invalid connection status') self.transport = transport + self.connected = True self.connected_fut.set_result(True) self.connected_fut = None diff --git a/edb/server/pgcon/rust_transport.py b/edb/server/pgcon/rust_transport.py new file mode 100644 index 000000000000..ba50de93d0ec --- /dev/null +++ b/edb/server/pgcon/rust_transport.py @@ -0,0 +1,473 @@ +""" +This module implements a Rust-based transport for PostgreSQL connections. + +The PGRawConn class provides a high-level interface for establishing and +managing PostgreSQL connections using a Rust-implemented state machine. 
It +handles the complexities of connection establishment, including SSL negotiation +and authentication, while presenting a simple asyncio-like transport interface +to the caller. +""" + +import asyncio +import ssl as ssl_module +import socket +import warnings +from typing import Optional, List, Tuple, Protocol, Callable, Dict, Any, TypeVar +from enum import Enum, auto +from edb.server._pg_rust import PyConnectionState +from dataclasses import dataclass +from edb.server.pgconnparams import ( + ConnectionParams, + SSLMode, + get_pg_home_directory +) +from . import errors as pgerror + +TCP_KEEPIDLE = 24 +TCP_KEEPINTVL = 2 +TCP_KEEPCNT = 3 + + +class ConnectionStateType(Enum): + CONNECTING = 0 + SSL_CONNECTING = auto() + AUTHENTICATING = auto() + SYNCHRONIZING = auto() + READY = auto() + + +class Authentication(Enum): + NONE = 0 + PASSWORD = auto() + MD5 = auto() + SCRAM_SHA256 = auto() + + +@dataclass +class PGState: + parameters: Dict[str, str] + cancellation_key: Optional[Tuple[int, int]] + auth: Optional[Authentication] + server_error: Optional[list[tuple[str, str]]] + ssl: bool + + +class ConnectionStateUpdate(Protocol): + def send(self, message: memoryview) -> None: ... + def upgrade(self) -> None: ... + def parameter(self, name: str, value: str) -> None: ... + def cancellation_key(self, pid: int, key: int) -> None: ... + def state_changed(self, state: int) -> None: ... + def auth(self, auth: int) -> None: ... 
+ + +StateChangeCallback = Callable[[ConnectionStateType], None] + + +def _parse_tls_version(tls_version: str) -> ssl_module.TLSVersion: + if tls_version.startswith('SSL'): + raise ValueError( + f"Unsupported TLS version: {tls_version}" + ) + try: + return ssl_module.TLSVersion[tls_version.replace('.', '_')] + except KeyError: + raise ValueError( + f"No such TLS version: {tls_version}" + ) + + +def _create_ssl(ssl_config: Dict[str, Any]): + sslmode = SSLMode.parse(ssl_config['sslmode']) + ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT) + ssl.check_hostname = sslmode >= SSLMode.verify_full + if sslmode < SSLMode.require: + ssl.verify_mode = ssl_module.CERT_NONE + else: + if ssl_config['sslrootcert']: + ssl.load_verify_locations(ssl_config['sslrootcert']) + ssl.verify_mode = ssl_module.CERT_REQUIRED + else: + if sslmode == SSLMode.require: + ssl.verify_mode = ssl_module.CERT_NONE + if ssl_config['sslcrl']: + ssl.load_verify_locations(ssl_config['sslcrl']) + ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN + if ssl_config['sslkey'] and ssl_config['sslcert']: + ssl.load_cert_chain(ssl_config['sslcert'], + ssl_config['sslkey'], + ssl_config['sslpassword'] or '') + if ssl_config['ssl_max_protocol_version']: + ssl.maximum_version = _parse_tls_version( + ssl_config['ssl_max_protocol_version']) + if ssl_config['ssl_min_protocol_version']: + ssl.minimum_version = _parse_tls_version( + ssl_config['ssl_min_protocol_version']) + # OpenSSL 1.1.1 keylog file + if hasattr(ssl, 'keylog_filename'): + if ssl_config['keylog_filename']: + ssl.keylog_filename = ssl_config['keylog_filename'] + return ssl + + +class PGConnectionProtocol(asyncio.Protocol): + """A protocol that manages the initial connection and authentication process + for PostgreSQL. + + This protocol acts as an intermediary between the raw socket connection and + the user's protocol. 
+ """ + def __init__( + self, + state: PyConnectionState, + protocol: asyncio.Protocol, + ready_future: asyncio.Future, + pg_state: PGState, + ): + self.state = state + self.pg_state = pg_state + self.protocol = protocol + self.ready_future = ready_future + self._writing_paused = False + self._ready = False + + def data_received(self, data: bytes): + if self._ready: + self.protocol.data_received(data) + else: + try: + self.state.drive_message(memoryview(data)) + except Exception as e: + if error := self.pg_state.server_error: + self.pg_state.server_error = None + self.ready_future.set_exception(pgerror.BackendConnectionError(fields=dict(error))) + else: + self.ready_future.set_exception(ConnectionError(e)) + if self.state.is_ready(): + self._ready = True + self.ready_future.set_result(True) + + def connection_lost(self, exc): + if self._ready: + self.protocol.connection_lost(exc) + else: + if not self.ready_future.done: + if exc: + self.ready_future.set_exception(exc) + else: + self.ready_future.set_exception(RuntimeError( + "Connection unexpectedly lost" + )) + + def pause_writing(self): + self._writing_paused = True + if self._ready: + self.protocol.pause_writing() + + def resume_writing(self): + self._writing_paused = False + if self._ready: + self.protocol.resume_writing() + + def is_ready(self): + return self._ready + + +class PGRawConn(asyncio.Transport): + def __init__(self, + source_description: Optional[str], + connection: ConnectionParams, + raw_transport: asyncio.Transport, + pg_state: PGState, + addr: tuple[str, int]): + super().__init__() + self.source_description = source_description + self.connection = connection + self.raw_transport = raw_transport + self._pg_state = pg_state + self.addr = addr + + @property + def state(self): + return self._pg_state + + def write(self, data: bytes | bytearray | memoryview): + self.raw_transport.write(data) + + def close(self): + self.raw_transport.close() + + def is_closing(self): + return 
self.raw_transport.is_closing() + + def get_extra_info(self, name: str, default=None): + return self.raw_transport.get_extra_info(name, default) + + def pause_reading(self): + self.raw_transport.pause_reading() + + def resume_reading(self): + self.raw_transport.resume_reading() + + def is_reading(self): + return self.raw_transport.is_reading() + + def abort(self): + self.raw_transport.abort() + + def __repr__(self): + params = ', '.join(f"{k}={v}" for k, v in + self._pg_state.parameters.items()) + auth_str = f", auth={self._pg_state.auth.name}" \ + if self._pg_state.auth else "" + source_str = f", source={self.source_description}" \ + if self.source_description else "" + raw_repr = repr(self.raw_transport) + dsn = self.connection._params + return (f"") + + def __del__(self): + if not self.raw_transport.is_closing(): + warnings.warn( + f"unclosed transport {repr(self)}", + ResourceWarning, + stacklevel=2 + ) + + +class RustTransportUpdate(ConnectionStateUpdate): + raw_transport: asyncio.Transport + state: PyConnectionState + state_change_callback: Optional[StateChangeCallback] + + def __init__( + self, + state: PyConnectionState, + raw_transport: asyncio.Transport, + state_change_callback: Optional[StateChangeCallback], + pg_state: PGState, + host: Optional[str], + ready_future: asyncio.Future, + ): + self.state = state + self._pg_state = pg_state + self.raw_transport = raw_transport + self._state_change_callback = state_change_callback + self._host = host + self._ready_future = ready_future + + def send(self, message: memoryview) -> None: + self.raw_transport.write(bytes(message)) + + def upgrade(self) -> None: + asyncio.create_task(self._upgrade_to_ssl()) + + async def _upgrade_to_ssl(self): + try: + config = self.state.config + ssl_context = _create_ssl(config) + loop = asyncio.get_running_loop() + new_transport = await loop.start_tls( + self.raw_transport, + self.raw_transport.get_protocol(), + ssl_context, + server_side=False, + ssl_handshake_timeout=None, + 
server_hostname=self._host + ) + self.raw_transport = new_transport + self.state.drive_ssl_ready() + self._pg_state.ssl = True + except Exception as e: + self._ready_future.set_exception(e) + raise + + def parameter(self, name: str, value: str) -> None: + self._pg_state.parameters[name] = value + + def cancellation_key(self, pid: int, key: int) -> None: + self._pg_state.cancellation_key = (pid, key) + + def state_changed(self, state: int) -> None: + if self._state_change_callback is not None: + self._state_change_callback(ConnectionStateType(state)) + + def auth(self, auth: int) -> None: + self._pg_state.auth = Authentication(auth) + + def server_error(self, error: list[tuple[str, str]]) -> None: + self._pg_state.server_error = error + + +async def _create_connection_to( + protocol_factory: Callable[[], asyncio.Protocol], + protocol: str, + host: str, + port: int + ) -> Tuple[str, str, int, asyncio.Transport]: + if protocol == "unix": + t, _ = await asyncio.get_running_loop().create_unix_connection( + protocol_factory, + f"{host}/.s.PGSQL.{port}" + ) + return (protocol, host, port, t) + else: + t, _ = await asyncio.get_running_loop().create_connection( + protocol_factory, + host, port + ) + _set_tcp_keepalive(t) + return (protocol, host, port, t) + +async def _create_connection( + protocol_factory: Callable[[], asyncio.Protocol], + connect_timeout: Optional[int], + host_candidates: List[Tuple[str, str, int]] + ) -> Tuple[str, str, int, asyncio.Transport]: + e = None + for protocol, host, port in host_candidates: + async with asyncio.timeout(connect_timeout if connect_timeout else 60): + try: + return await _create_connection_to( + protocol_factory, + protocol, + host, + port + ) + except asyncio.CancelledError as ex: + raise pgerror.new( + pgerror.ERROR_CONNECTION_FAILURE, + "timed out connecting to backend", + ) from ex + except Exception as ex: + e = ex + continue + raise ConnectionError(f"Failed to connect to any of " + f"the provided hosts: {host_candidates}") 
from e + + +def _set_tcp_keepalive(transport): + # TCP keepalive was initially added here for special cases where idle + # connections are dropped silently on GitHub Action running test suite + # against AWS RDS. We are keeping the TCP keepalive for generic + # Postgres connections as the kernel overhead is considered low, and + # in certain cases it does save us some reconnection time. + # + # In case of high-availability Postgres, TCP keepalive is necessary to + # disconnect from a failing master node, if no other failover information + # is available. + sock = transport.get_extra_info('socket') + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + + # TCP_KEEPIDLE: the time (in seconds) the connection needs to remain idle + # before TCP starts sending keepalive probes. This is socket.TCP_KEEPIDLE + # on Linux, and socket.TCP_KEEPALIVE on macOS from Python 3.10. + if hasattr(socket, 'TCP_KEEPIDLE'): + sock.setsockopt(socket.IPPROTO_TCP, + socket.TCP_KEEPIDLE, TCP_KEEPIDLE) + if hasattr(socket, 'TCP_KEEPALIVE'): + sock.setsockopt(socket.IPPROTO_TCP, + socket.TCP_KEEPALIVE, TCP_KEEPIDLE) + + # TCP_KEEPINTVL: The time (in seconds) between individual keepalive probes. + if hasattr(socket, 'TCP_KEEPINTVL'): + sock.setsockopt(socket.IPPROTO_TCP, + socket.TCP_KEEPINTVL, TCP_KEEPINTVL) + + # TCP_KEEPCNT: The maximum number of keepalive probes TCP should send + # before dropping the connection. + if hasattr(socket, 'TCP_KEEPCNT'): + sock.setsockopt(socket.IPPROTO_TCP, + socket.TCP_KEEPCNT, TCP_KEEPCNT) + + +P = TypeVar('P', bound=asyncio.Protocol) + + +async def create_postgres_connection( + dsn: str | ConnectionParams, + protocol_factory: Callable[[], P], + *, + source_description: Optional[str] = None, + state_change_callback: Optional[StateChangeCallback] = None +) -> Tuple[PGRawConn, P]: + """ + Open a PostgreSQL connection to the address specified by the DSN. + + The DSN (Data Source Name) should include connection details like host, + port, database, etc. 
+ + protocol_factory must be a callable returning an asyncio protocol + implementation. + + This method establishes the connection asynchronously. When successful, it + returns a (PGRawConn, protocol) pair. + + :param dsn: Data Source Name for the PostgreSQL connection + :param protocol_factory: Callable that returns an asyncio protocol + :param state_change_callback: Optional callback for connection state changes + :return: Tuple of PGRawConn and asyncio.Protocol + """ + if isinstance(dsn, str): + dsn = ConnectionParams(dsn=dsn) + connect_timeout = dsn.connect_timeout + try: + state = PyConnectionState( + dsn._params, + "postgres", + str(get_pg_home_directory()) + ) + except Exception as e: + raise ValueError(e) + ready_future: asyncio.Future = asyncio.Future() + pg_state = PGState( + parameters={}, + cancellation_key=None, + auth=None, + server_error=None, + ssl=False + ) + + user_protocol = protocol_factory() + protocol, host, port, raw_transport = await _create_connection( + lambda: PGConnectionProtocol( + state, + user_protocol, + ready_future, + pg_state + ), + connect_timeout, + state.config.hosts + ) + + try: + update = RustTransportUpdate( + state, + raw_transport, + state_change_callback, + pg_state, + host if protocol == "tcp" else None, + ready_future + ) + state.update = update + state.drive_initial() + + await ready_future + raw_transport = update.raw_transport + conn = PGRawConn( + source_description, + ConnectionParams._create(state.config), + raw_transport, + pg_state, + (host, port) + ) + raw_transport.set_protocol(user_protocol) + + # Notify the user protocol of successful connection + user_protocol.connection_made(conn) + except: + raw_transport.abort() + raise + + return conn, user_protocol diff --git a/edb/server/pgcon/scram.pxd b/edb/server/pgcon/scram.pxd deleted file mode 100644 index d7e096766e66..000000000000 --- a/edb/server/pgcon/scram.pxd +++ /dev/null @@ -1,46 +0,0 @@ -# This file is copied from: -# 
https://github.com/MagicStack/asyncpg/blob/383c711e/asyncpg/protocol/scram.pxd -# -# Copyright (C) 2021-present MagicStack Inc. and the EdgeDB authors. -# Copyright (C) 2016-present the asyncpg authors and contributors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -cdef class SCRAMAuthentication: - cdef: - readonly bytes authentication_method - readonly bytes authorization_message - readonly bytes client_channel_binding - readonly bytes client_first_message_bare - readonly bytes client_nonce - readonly bytes client_proof - readonly bytes password_salt - readonly int password_iterations - readonly bytes server_first_message - # server_key is an instance of hmac.HAMC - readonly object server_key - readonly bytes server_nonce - - cdef create_client_first_message(self, str username) - cdef create_client_final_message(self, str password) - cdef parse_server_first_message(self, bytes server_response) - cdef verify_server_final_message(self, bytes server_final_message) - cdef _bytes_xor(self, bytes a, bytes b) - cdef _generate_client_nonce(self, int num_bytes) - cdef _generate_client_proof(self, str password) - cdef _generate_salted_password( - self, str password, bytes salt, int iterations - ) - cdef _normalize_password(self, str original_password) diff --git a/edb/server/pgcon/scram.pyx b/edb/server/pgcon/scram.pyx deleted file mode 100644 index f39771755d07..000000000000 --- a/edb/server/pgcon/scram.pyx +++ /dev/null @@ -1,370 +0,0 @@ -# This file is copied from: 
-# https://github.com/MagicStack/asyncpg/blob/383c711e/asyncpg/protocol/scram.pyx -# -# Copyright (C) 2021-present MagicStack Inc. and the EdgeDB authors. -# Copyright (C) 2016-present the asyncpg authors and contributors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import base64 -import hashlib -import hmac -import re -import stringprep -import unicodedata - - -# try to import the secrets library from Python 3.6+ for the -# cryptographic token generator for generating nonces as part of SCRAM -# Otherwise fall back on os.urandom -try: - from secrets import token_bytes as generate_token_bytes -except ImportError: - from os import urandom as generate_token_bytes - -@cython.final -cdef class SCRAMAuthentication: - """Contains the protocol for generating and a SCRAM hashed password. - - Since PostgreSQL 10, the option to hash passwords using the SCRAM-SHA-256 - method was added. This module follows the defined protocol, which can be - referenced from here: - - https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256 - - libpq references the following RFCs that it uses for implementation: - - * RFC 5802 - * RFC 5803 - * RFC 7677 - - The protocol works as such: - - - A client connects to the server. The server requests the client to begin - SASL authentication using SCRAM and presents a client with the methods it - supports. 
At present, those are SCRAM-SHA-256, and, on servers that are - built with OpenSSL and - are PG11+, SCRAM-SHA-256-PLUS (which supports channel binding, more on that - below) - - - The client sends a "first message" to the server, where it chooses which - method to authenticate with, and sends, along with the method, an - indication of channel binding (we disable for now), a nonce, and the - username. (Technically, PostgreSQL ignores the username as it already has - it from the initial connection, but we add it for completeness) - - - The server responds with a "first message" in which it extends the nonce, - as well as a password salt and the number of iterations to hash the - password with. The client validates that the new nonce contains the first - part of the client's original nonce - - - The client generates a salted password, but does not sent this up to the - server. Instead, the client follows the SCRAM algorithm (RFC5802) to - generate a proof. This proof is sent as part of a client "final message" to - the server for it to validate. - - - The server validates the proof. If it is valid, the server sends a - verification code for the client to verify that the server came to the same - proof the client did. PostgreSQL immediately sends an AuthenticationOK - response right after a valid negotiation. If the password the client - provided was invalid, then authentication fails. - - (The beauty of this is that the salted password is never transmitted over - the wire!) - - PostgreSQL 11 added support for the channel binding (i.e. - SCRAM-SHA-256-PLUS) but to do some ongoing discussion, there is a conscious - decision by several driver authors to not support it as of yet. As such, - the channel binding parameter is hard-coded to "n" for now, but can be - updated to support other channel binding methods in the future. 
- """ - AUTHENTICATION_METHODS = [b"SCRAM-SHA-256"] - DEFAULT_CLIENT_NONCE_BYTES = 24 - DIGEST = hashlib.sha256 - REQUIREMENTS_CLIENT_FINAL_MESSAGE = ['client_channel_binding', - 'server_nonce'] - REQUIREMENTS_CLIENT_PROOF = ['password_iterations', 'password_salt', - 'server_first_message', 'server_nonce'] - SASLPREP_PROHIBITED = ( - stringprep.in_table_a1, # PostgreSQL treats this as prohibited - stringprep.in_table_c12, - stringprep.in_table_c21_c22, - stringprep.in_table_c3, - stringprep.in_table_c4, - stringprep.in_table_c5, - stringprep.in_table_c6, - stringprep.in_table_c7, - stringprep.in_table_c8, - stringprep.in_table_c9, - ) - - def __cinit__(self, bytes authentication_method): - self.authentication_method = authentication_method - self.authorization_message = None - # channel binding is turned off for the time being - self.client_channel_binding = b"n,," - self.client_first_message_bare = None - self.client_nonce = None - self.client_proof = None - self.password_salt = None - # self.password_iterations = None - self.server_first_message = None - self.server_key = None - self.server_nonce = None - - cdef create_client_first_message(self, str username): - """Create the initial client message for SCRAM authentication""" - cdef: - bytes msg - bytes client_first_message - - self.client_nonce = \ - self._generate_client_nonce(self.DEFAULT_CLIENT_NONCE_BYTES) - # set the client first message bare here, as it's used in a later step - self.client_first_message_bare = b"n=" + username.encode("utf-8") + \ - b",r=" + self.client_nonce - # put together the full message here - msg = bytes() - msg += self.authentication_method + b"\0" - client_first_message = self.client_channel_binding + \ - self.client_first_message_bare - msg += (len(client_first_message)).to_bytes(4, byteorder='big') + \ - client_first_message - return msg - - cdef create_client_final_message(self, str password): - """Create the final client message as part of SCRAM authentication""" - cdef: - 
bytes msg - - if any([getattr(self, val) is None for val in - self.REQUIREMENTS_CLIENT_FINAL_MESSAGE]): - raise RuntimeError( - "you need values from server to generate a client proof") - - # normalize the password using the SASLprep algorithm in RFC 4013 - password = self._normalize_password(password) - - # generate the client proof - self.client_proof = self._generate_client_proof(password=password) - msg = bytes() - msg += b"c=" + base64.b64encode(self.client_channel_binding) + \ - b",r=" + self.server_nonce + \ - b",p=" + base64.b64encode(self.client_proof) - return msg - - cdef parse_server_first_message(self, bytes server_response): - """Parse the response from the first message from the server""" - self.server_first_message = server_response - try: - self.server_nonce = re.search(b'r=([^,]+),', - self.server_first_message).group(1) - except IndexError: - raise RuntimeError("could not get nonce") - if not self.server_nonce.startswith(self.client_nonce): - raise pgerror.BackendError(fields=dict( - M="server SCRAM nonce does not match", - C=pgerror.ERROR_INVALID_PASSWORD, - )) - try: - self.password_salt = re.search(b's=([^,]+),', - self.server_first_message).group(1) - except IndexError: - raise RuntimeError("could not get salt") - try: - self.password_iterations = int(re.search(b'i=(\d+),?', - self.server_first_message).group(1)) - except (IndexError, TypeError, ValueError): - raise RuntimeError("could not get iterations") - - cdef verify_server_final_message(self, bytes server_final_message): - """Verify the final message from the server""" - cdef: - bytes server_signature - - try: - server_signature = re.search(b'v=([^,]+)', - server_final_message).group(1) - except IndexError: - raise RuntimeError("could not get server signature") - - verify_server_signature = hmac.new(self.server_key.digest(), - self.authorization_message, self.DIGEST) - # validate the server signature against the verifier - return server_signature == base64.b64encode( - 
verify_server_signature.digest()) - - cdef _bytes_xor(self, bytes a, bytes b): - """XOR two bytestrings together""" - return bytes(a_i ^ b_i for a_i, b_i in zip(a, b)) - - cdef _generate_client_nonce(self, int num_bytes): - cdef: - bytes token - - token = generate_token_bytes(num_bytes) - - return base64.b64encode(token) - - cdef _generate_client_proof(self, str password): - """need to ensure a server response exists, i.e. """ - cdef: - bytes salted_password - - if any([getattr(self, val) is None for val in - self.REQUIREMENTS_CLIENT_PROOF]): - raise RuntimeError( - "you need values from server to generate a client proof") - # generate a salt password - salted_password = self._generate_salted_password(password, - self.password_salt, self.password_iterations) - # client key is derived from the salted password - client_key = hmac.new(salted_password, b"Client Key", self.DIGEST) - # this allows us to compute the stored key that is residing on the - # server - stored_key = self.DIGEST(client_key.digest()) - # as well as compute the server key - self.server_key = hmac.new(salted_password, b"Server Key", self.DIGEST) - # build the authorization message that will be used in the - # client signature - # the "c=" portion is for the channel binding, but this is not - # presently implemented - self.authorization_message = self.client_first_message_bare + b"," + \ - self.server_first_message + b",c=" + \ - base64.b64encode(self.client_channel_binding) + \ - b",r=" + self.server_nonce - # sign! 
- client_signature = hmac.new(stored_key.digest(), - self.authorization_message, self.DIGEST) - # and the proof - return self._bytes_xor(client_key.digest(), client_signature.digest()) - - cdef _generate_salted_password( - self, str password, bytes salt, int iterations - ): - """This follows the "Hi" algorithm specified in RFC5802""" - cdef: - bytes p - bytes s - bytes u - - # convert the password to a binary string - UTF8 is safe for SASL - # (though there are SASLPrep rules) - p = password.encode("utf8") - # the salt needs to be base64 decoded -- full binary must be used - s = base64.b64decode(salt) - # the initial signature is the salt with a terminator of a 32-bit - # string ending in 1 - ui = hmac.new(p, s + b'\x00\x00\x00\x01', self.DIGEST) - # grab the initial digest - u = ui.digest() - # for X number of iterations, recompute the HMAC signature against the - # password and the latest iteration of the hash, and XOR it with the - # previous version - for x in range(iterations - 1): - ui = hmac.new(p, ui.digest(), hashlib.sha256) - # this is a fancy way of XORing two byte strings together - u = self._bytes_xor(u, ui.digest()) - return u - - cdef _normalize_password(self, str original_password): - """Normalize the password using the SASLprep from RFC4013""" - cdef: - str normalized_password - - # Note: Per the PostgreSQL documentation, PostgreSWL does not require - # UTF-8 to be used for the password, but will perform SASLprep on the - # password regardless. - # If the password is not valid UTF-8, PostgreSQL will then **not** use - # SASLprep processing. 
- # If the password fails SASLprep, the password should still be sent - # See: https://www.postgresql.org/docs/current/sasl-authentication.html - # and - # https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/common/saslprep.c - # using the `pg_saslprep` function - normalized_password = original_password - # if the original password is an ASCII string or fails to encode as a - # UTF-8 string, then no further action is needed - try: - original_password.encode("ascii") - except UnicodeEncodeError: - pass - else: - return original_password - - # Step 1 of SASLPrep: Map. Per the algorithm, we map non-ascii space - # characters to ASCII spaces (\x20 or \u0020, but we will use ' ') and - # commonly mapped to nothing characters are removed - # Table C.1.2 -- non-ASCII spaces - # Table B.1 -- "Commonly mapped to nothing" - normalized_password = u"".join( - ' ' if stringprep.in_table_c12(c) else c - for c in tuple(normalized_password) - if not stringprep.in_table_b1(c) - ) - - # If at this point the password is empty, PostgreSQL uses the original - # password - if not normalized_password: - return original_password - - # Step 2 of SASLPrep: Normalize. Normalize the password using the - # Unicode normalization algorithm to NFKC form - normalized_password = unicodedata.normalize( - 'NFKC', normalized_password - ) - - # If the password is not empty, PostgreSQL uses the original password - if not normalized_password: - return original_password - - normalized_password_tuple = tuple(normalized_password) - - # Step 3 of SASLPrep: Prohibited characters. 
If PostgreSQL detects any - # of the prohibited characters in SASLPrep, it will use the original - # password - # We also include "unassigned code points" in the prohibited character - # category as PostgreSQL does the same - for c in normalized_password_tuple: - if any( - in_prohibited_table(c) - for in_prohibited_table in self.SASLPREP_PROHIBITED - ): - return original_password - - # Step 4 of SASLPrep: Bi-directional characters. PostgreSQL follows the - # rules for bi-directional characters laid on in RFC3454 Sec. 6 which - # are: - # 1. Characters in RFC 3454 Sec 5.8 are prohibited (C.8) - # 2. If a string contains a RandALCat character, it cannot contain any - # LCat character - # 3. If the string contains any RandALCat character, an RandALCat - # character must be the first and last character of the string - # RandALCat characters are found in table D.1, whereas LCat are in D.2 - if any(stringprep.in_table_d1(c) for c in normalized_password_tuple): - # if the first character or the last character are not in D.1, - # return the original password - if not (stringprep.in_table_d1(normalized_password_tuple[0]) and - stringprep.in_table_d1(normalized_password_tuple[-1])): - return original_password - - # if any characters are in D.2, use the original password - if any( - stringprep.in_table_d2(c) for c in normalized_password_tuple - ): - return original_password - - # return the normalized password - return normalized_password diff --git a/edb/server/pgconnparams.py b/edb/server/pgconnparams.py index 03df7ba25bd1..1d86a5b19094 100644 --- a/edb/server/pgconnparams.py +++ b/edb/server/pgconnparams.py @@ -12,52 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - -from __future__ import annotations -from typing import Optional, Tuple, Union, Dict, List - -import dataclasses -import enum -import getpass +# import pathlib import platform -import ssl as ssl_module +import ssl import warnings -import edb.server._pg_rust - - -class SSLMode(enum.IntEnum): - disable = 0 - allow = 1 - prefer = 2 - require = 3 - verify_ca = 4 - verify_full = 5 - - @classmethod - def parse(cls, sslmode: Union[SSLMode, str]) -> SSLMode: - if isinstance(sslmode, SSLMode): - rv = sslmode - else: - rv = getattr(cls, sslmode.replace('-', '_')) - return rv - - -@dataclasses.dataclass -class ConnectionParameters: - user: str - password: Optional[str] = None - database: Optional[str] = None - ssl: Optional[ssl_module.SSLContext] = None - sslmode: Optional[SSLMode] = None - server_settings: Dict[str, str] = dataclasses.field(default_factory=dict) - connect_timeout: Optional[int] = None - +from typing import TypedDict, NotRequired, Optional, Unpack, Self, Any +from enum import IntEnum +from edb.server._pg_rust import PyConnectionParams _system = platform.uname().system - - if _system == 'Windows': import ctypes.wintypes @@ -81,110 +45,141 @@ def get_pg_home_directory() -> pathlib.Path: return pathlib.Path.home() / '.postgresql' -def _parse_tls_version(tls_version: str) -> ssl_module.TLSVersion: - if tls_version.startswith('SSL'): - raise ValueError( - f"Unsupported TLS version: {tls_version}" - ) - try: - return ssl_module.TLSVersion[tls_version.replace('.', '_')] - except KeyError: - raise ValueError( - f"No such TLS version: {tls_version}" - ) - +class SSLMode(IntEnum): + disable = 0 + allow = 1 + prefer = 2 + require = 3 + verify_ca = 4 + verify_full = 5 -def parse_dsn( - dsn: str, -) -> Tuple[ - Tuple[Tuple[str, int], ...], - ConnectionParameters, -]: - try: - parsed, ssl_paths = edb.server._pg_rust.parse_dsn(getpass.getuser(), - str(get_pg_home_directory()), - dsn) - except Exception as e: - raise ValueError(f"{e.args[0]}") from e - - # Extract SSL 
configuration from the dict - ssl = None - sslmode = SSLMode.disable - ssl_config = parsed['ssl'] - if 'Enable' in ssl_config: - ssl_config = ssl_config['Enable'] - ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT) - sslmode = SSLMode.parse(ssl_config[0].lower()) - ssl.check_hostname = sslmode >= SSLMode.verify_full - ssl_config = ssl_config[1] - if sslmode < SSLMode.require: - ssl.verify_mode = ssl_module.CERT_NONE + @classmethod + def parse(cls, sslmode: str) -> Self: + value: Self = getattr(cls, sslmode.replace('-', '_')) + assert value is not None, f"Invalid SSL mode: {sslmode}" + return value + + +class CreateParamsKwargs(TypedDict, total=False): + dsn: NotRequired[str] + hosts: NotRequired[Optional[list[tuple[str, int]]]] + host: NotRequired[Optional[str]] + user: NotRequired[Optional[str]] + database: NotRequired[Optional[str]] + server_settings: NotRequired[Optional[dict[str, str]]] + sslmode: NotRequired[Optional[SSLMode]] + sslrootcert: NotRequired[Optional[str]] + connect_timeout: NotRequired[Optional[int]] + + +# This is a Python representation of the Rust connection parameters that are +# passed back during connection/parse. 
+class ConnectionParams: + _params: PyConnectionParams + + def __init__(self, **kwargs: Unpack[CreateParamsKwargs]) -> None: + dsn = kwargs.pop("dsn", None) + if dsn: + self._params = PyConnectionParams(dsn) else: - if ssl_paths['rootcert']: - ssl.load_verify_locations(ssl_paths['rootcert']) - ssl.verify_mode = ssl_module.CERT_REQUIRED - else: - if sslmode == SSLMode.require: - ssl.verify_mode = ssl_module.CERT_NONE - if ssl_paths['crl']: - ssl.load_verify_locations(ssl_paths['crl']) - ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN - if ssl_paths['key'] and ssl_paths['cert']: - ssl.load_cert_chain(ssl_paths['cert'], - ssl_paths['key'], - ssl_config['password'] or '') - if ssl_config['max_protocol_version']: - ssl.maximum_version = _parse_tls_version( - ssl_config['max_protocol_version']) - if ssl_config['min_protocol_version']: - ssl.minimum_version = _parse_tls_version( - ssl_config['min_protocol_version']) - # OpenSSL 1.1.1 keylog file - if hasattr(ssl, 'keylog_filename'): - if ssl_config['keylog_filename']: - ssl.keylog_filename = ssl_config['keylog_filename'] - - # Extract hosts from the dict - addrs: List[Tuple[str, int]] = [] - for host in parsed['hosts']: - if 'Hostname' in host: - host, port = host['Hostname'] - addrs.append((host, port)) - elif 'IP' in host: - ip, port, scope = host['IP'] - # Reconstruct the scope ID - if scope: - ip = f'{ip}%{scope}' - addrs.append((ip, port)) - elif 'Path' in host: - path, port = host['Path'] - addrs.append((path, port)) - elif 'Abstract' in host: - path, port = host['Abstract'] - addrs.append((path, port)) - - # Database/user/password/connect_timeout - database: str = str(parsed['database']) or '' - user: str = str(parsed['user']) or '' - connect_timeout = parsed['connect_timeout']['secs'] \ - if parsed['connect_timeout'] else None - - # Extract password from the dict - password: str | None = "" - password_config = parsed['password'] - if 'Unspecified' in password_config: - password = '' - elif 'Specified' in 
password_config: - password = password_config['Specified'] - - params = ConnectionParameters( - user=user, - password=password, - database=database, - ssl=ssl, - sslmode=sslmode, - server_settings=parsed['server_settings'], - connect_timeout=connect_timeout, - ) - - return tuple(addrs), params + self._params = PyConnectionParams(None) + self.update(**kwargs) + + @classmethod + def _create( + cls, + params: dict[str, Any], + ssl: Optional[ssl.SSLContext] = None + ) -> Self: + instance = super().__new__(cls) + instance._params = params + return instance + + def update(self, **kwargs: Unpack[CreateParamsKwargs]) -> None: + if dsn := kwargs.pop('dsn', None): + params = PyConnectionParams(dsn) + for k, v in params.to_dict().items(): + self._params[k] = v + if server_settings := kwargs.pop("server_settings", None): + for k2, v2 in server_settings.items(): + self._params.update_server_settings(k2, v2) + if host_specs := kwargs.pop("hosts", None): + hosts, ports = zip(*host_specs) + self._params['host'] = ','.join(hosts) + self._params['port'] = ','.join(map(str, ports)) + if (ssl_mode := kwargs.pop("sslmode", None)) is not None: + mode: SSLMode = ssl_mode + self._params["sslmode"] = mode.name + if connect_timeout := kwargs.pop("connect_timeout", None): + self._params["connect_timeout"] = str(connect_timeout) + for k, v in kwargs.items(): + if k == "database": + k = "dbname" + self._params[k] = v + + def resolve(self) -> Self: + return self._create( + self._params.resolve("", str(get_pg_home_directory())), + ) + + def clone(self) -> Self: + return self._create(self._params.clone()) + + @property + def hosts(self) -> Optional[list[tuple[dict[str, Any], int]]]: + return self._params['hosts'] # type: ignore + + @property + def host(self) -> Optional[str]: + return self._params['host'] # type: ignore + + @property + def port(self) -> Optional[int]: + return self._params['port'] # type: ignore + + @property + def user(self) -> Optional[str]: + return self._params['user'] # 
type: ignore + + @property + def password(self) -> Optional[str]: + return self._params['password'] # type: ignore + + @property + def database(self) -> Optional[str]: + return self._params['dbname'] # type: ignore + + @property + def connect_timeout(self) -> Optional[int]: + connect_timeout = self._params['connect_timeout'] + return int(connect_timeout) if connect_timeout else None + + @property + def sslmode(self) -> Optional[SSLMode]: + sslmode = self._params['sslmode'] + return SSLMode.parse(sslmode) if sslmode is not None else None + + def to_dsn(self) -> str: + dsn: str = self._params.to_dsn() + return dsn + + @property + def __dict__(self) -> dict[str, Any]: + to_dict: dict[str, str] = self._params.to_dict() + database = to_dict.pop('dbname', None) + if database: + to_dict['database'] = database + return to_dict + + @__dict__.setter + def __dict__(self, value: dict[str, Any]) -> None: + new_params = self._params.__class__() + try: + for k, v in value.items(): + new_params[k] = v + self._params = new_params + except Exception: + raise ValueError("Failed to update __dict__") + + def __repr__(self) -> Any: + return self._params.__repr__() diff --git a/edb/server/pgrust/Cargo.toml b/edb/server/pgrust/Cargo.toml index e0deccd23422..00a7a181e167 100644 --- a/edb/server/pgrust/Cargo.toml +++ b/edb/server/pgrust/Cargo.toml @@ -13,9 +13,11 @@ python_extension = ["pyo3/extension-module", "pyo3/serde"] optimizer = [] [dependencies] +pyo3.workspace = true +tokio.workspace = true + futures = "0" scopeguard = "1" -pyo3 = "0" itertools = "0" thiserror = "1" tracing = "0" @@ -23,13 +25,16 @@ tracing-subscriber = "0" strum = { version = "0.26", features = ["derive"] } consume_on_drop = "0" smart-default = "0" -openssl = "0.10.66" +openssl = { version = "0.10.66", features = ["v111"] } +tokio-openssl = "0.6.4" paste = "1" unicode-normalization = "0.1.23" stringprep = "0.1.5" hmac = "0.12" base64 = "0.22" sha2 = "0.10" +hex = "0.4.3" +md5 = "0.7.0" rand = "0" hexdump = "0" url 
= "2" @@ -37,10 +42,7 @@ serde = "1" serde_derive = "1" serde-pickle = "1" percent-encoding = "2" - -[dependencies.tokio] -version = "1" -features = ["rt", "time", "sync", "net", "io-util"] +roaring = "0.10.6" [dependencies.derive_more] version = "1.0.0-beta.6" @@ -57,6 +59,9 @@ byteorder = "1.5" clap = "4" clap_derive = "4" hex-literal = "0.4" +tempfile = "3" +socket2 = "0.5.7" +libc = "0.2.158" [dev-dependencies.tokio] version = "1" diff --git a/edb/server/pgrust/__init__.pyi b/edb/server/pgrust/__init__.pyi new file mode 100644 index 000000000000..b5de90fc6b65 --- /dev/null +++ b/edb/server/pgrust/__init__.pyi @@ -0,0 +1,2 @@ +def parse_dsn(home_dir: str, user: str, url: str) -> dict: + pass diff --git a/edb/server/pgrust/examples/connect.rs b/edb/server/pgrust/examples/connect.rs new file mode 100644 index 000000000000..24f5ffa06940 --- /dev/null +++ b/edb/server/pgrust/examples/connect.rs @@ -0,0 +1,176 @@ +use clap::Parser; +use clap_derive::Parser; +use openssl::ssl::{Ssl, SslContext, SslMethod}; +use pgrust::{ + connection::{ + parse_postgres_dsn_env, tokio::TokioSocketAddress, Client, Credentials, Host, HostType, + }, + protocol::{DataRow, ErrorResponse, RowDescription}, +}; +use std::net::{SocketAddr, ToSocketAddrs}; +use tokio::task::LocalSet; + +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[clap(short = 'D', long = "dsn", value_parser, conflicts_with_all = &["unix", "tcp", "username", "password", "database"])] + dsn: Option, + + /// Network socket address and port + #[clap(short = 't', long = "tcp", value_parser, conflicts_with = "unix")] + tcp: Option, + + /// Unix socket path + #[clap(short = 'u', long = "unix", value_parser, conflicts_with = "tcp")] + unix: Option, + + /// Username to use for the connection + #[clap( + short = 'U', + long = "username", + value_parser, + default_value = "postgres" + )] + username: String, + + /// Username to use for the connection + #[clap(short = 'P', long = 
"password", value_parser, default_value = "")] + password: String, + + /// Database to use for the connection + #[clap( + short = 'd', + long = "database", + value_parser, + default_value = "postgres" + )] + database: String, + + /// SQL statements to run + #[clap( + name = "statements", + trailing_var_arg = true, + allow_hyphen_values = true, + help = "Zero or more SQL statements to run (defaults to 'select 1')" + )] + statements: Option>, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt::init(); + let mut args = Args::parse(); + eprintln!("{args:?}"); + + let mut socket_address: Option = None; + if let Some(dsn) = args.dsn { + let mut conn = parse_postgres_dsn_env(&dsn, std::env::vars())?; + #[allow(deprecated)] + let home = std::env::home_dir().unwrap(); + conn.password + .resolve(&home, &conn.hosts, &conn.database, &conn.database)?; + args.database = conn.database; + args.username = conn.user; + args.password = conn.password.password().unwrap_or_default().to_string(); + let host = conn.hosts.first().unwrap(); + eprintln!("Connecting to {host:?}"); + match host { + Host(HostType::Path(path), port) => { + socket_address = Some(TokioSocketAddress::new_unix(format!( + "{path}/.s.PGSQL.{port}" + ))); + } + Host(HostType::Hostname(hostname), port) => { + let addr = format!("{hostname}:{port}") + .to_socket_addrs()? 
+ .next() + .unwrap(); + socket_address = Some(TokioSocketAddress::new_tcp(addr)); + } + Host(HostType::IP(ip, interface), port) => { + let addr = if let Some(interface) = interface { + format!("{ip}%{interface}") + } else { + ip.to_string() + }; + socket_address = Some(TokioSocketAddress::new_tcp( + format!("{addr}:{port}").parse()?, + )); + } + Host(HostType::Abstract(..), _) => { + unimplemented!("Abstract socket connection not yet supported") + } + } + } + + let socket_address = socket_address.unwrap_or_else(|| match (args.tcp, args.unix) { + (Some(addr), None) => TokioSocketAddress::new_tcp(addr), + (None, Some(path)) => TokioSocketAddress::new_unix(path), + _ => panic!("Must specify either a TCP address or a Unix socket path"), + }); + + let credentials = Credentials { + username: args.username, + password: args.password, + database: args.database, + server_settings: Default::default(), + }; + + let statements = args + .statements + .unwrap_or_else(|| vec!["select 1;".to_string()]); + let local = LocalSet::new(); + local + .run_until(run_queries(socket_address, credentials, statements)) + .await?; + + Ok(()) +} + +async fn run_queries( + socket_address: TokioSocketAddress, + credentials: Credentials, + statements: Vec, +) -> Result<(), Box> { + let client = socket_address.connect().await?; + let ssl = SslContext::builder(SslMethod::tls_client())?.build(); + let ssl = Ssl::new(&ssl)?; + + let (conn, task) = Client::new(credentials, client, ssl); + tokio::task::spawn_local(task); + conn.ready().await?; + + let local = LocalSet::new(); + eprintln!("Statements: {statements:?}"); + for statement in statements { + let sink = ( + |rows: RowDescription<'_>| { + eprintln!("\nFields:"); + for field in rows.fields() { + eprint!(" {:?}", field.name()); + } + eprintln!(); + let guard = scopeguard::guard((), |_| { + eprintln!("Done"); + }); + move |row: Result, ErrorResponse<'_>>| { + let _ = &guard; + if let Ok(row) = row { + eprintln!("Row:"); + for field in row.values() 
{ + eprint!(" {:?}", field); + } + eprintln!(); + } + } + }, + |error: ErrorResponse<'_>| { + eprintln!("\nError:\n {:?}", error); + }, + ); + local.spawn_local(conn.query(&statement, sink)); + } + local.await; + + Ok(()) +} diff --git a/edb/server/pgrust/examples/dsn.rs b/edb/server/pgrust/examples/dsn.rs index f63afe0ee125..3080dcc5ff19 100644 --- a/edb/server/pgrust/examples/dsn.rs +++ b/edb/server/pgrust/examples/dsn.rs @@ -1,9 +1,9 @@ -use pgrust::parse_postgres_url; +use pgrust::connection::parse_postgres_dsn_env; fn main() { let dsn = std::env::args().nth(1).expect("No DSN provided"); - let mut params = parse_postgres_url(&dsn, std::env::vars()).unwrap(); + let mut params = parse_postgres_dsn_env(&dsn, std::env::vars()).unwrap(); #[allow(deprecated)] let home = std::env::home_dir().unwrap(); eprintln!("DSN: {dsn}\n----\n{:#?}\n", params); @@ -15,6 +15,6 @@ fn main() { "Resolved password:\n------------------\n{:#?}\n", params.password ); - let ssl = params.ssl.resolve(&home).unwrap(); - eprintln!("Resolved SSL:\n-------------\n{:#?}\n", ssl); + params.ssl.resolve(&home).unwrap(); + eprintln!("Resolved SSL:\n-------------\n{:#?}\n", ()); } diff --git a/edb/server/pgrust/src/auth/md5.rs b/edb/server/pgrust/src/auth/md5.rs new file mode 100644 index 000000000000..0562b35d690d --- /dev/null +++ b/edb/server/pgrust/src/auth/md5.rs @@ -0,0 +1,50 @@ +/// Computes the MD5 password hash used in PostgreSQL authentication. +/// +/// This function implements the MD5 password hashing algorithm as specified in the +/// PostgreSQL documentation for MD5 authentication. +/// +/// # Algorithm +/// +/// 1. Concatenate the password and username. +/// 2. Calculate the MD5 hash of this concatenated string. +/// 3. Concatenate the hexadecimal representation of the hash from step 2 with the salt. +/// 4. Calculate the MD5 hash of the result from step 3. +/// 5. Return the final hash as hex, prefixed with "md5". 
+/// +/// # Example +/// +/// ``` +/// let password = "secret"; +/// let username = "user"; +/// let salt = [0x01, 0x02, 0x03, 0x04]; +/// let hash = pgrust::auth::md5_password(password, username, &salt); +/// assert_eq!(hash, "md5fccef98e4f1cf6cbe96b743fad4e8bd0"); +/// ``` +pub fn md5_password(password: &str, username: &str, salt: &[u8; 4]) -> String { + // First MD5 hash of password + username + let mut hasher = md5::Context::new(); + hasher.consume(password.as_bytes()); + hasher.consume(username.as_bytes()); + let first_hash = hasher.compute(); + + // Convert first hash to hex string + let first_hash_hex = to_hex_string(&first_hash.0); + + // Second MD5 hash of first hash + salt + let mut hasher = md5::Context::new(); + hasher.consume(first_hash_hex.as_bytes()); + hasher.consume(salt); + let second_hash = hasher.compute(); + + // Combine 'md5' prefix with final hash + format!("md5{}", to_hex_string(&second_hash.0)) +} + +/// Converts a byte slice to a hexadecimal string. +fn to_hex_string(bytes: &[u8]) -> String { + let mut hex = String::with_capacity(bytes.len() * 2); + for &byte in bytes { + hex.push_str(&format!("{:02x}", byte)); + } + hex +} diff --git a/edb/server/pgrust/src/auth/mod.rs b/edb/server/pgrust/src/auth/mod.rs new file mode 100644 index 000000000000..ae3c8891e3dc --- /dev/null +++ b/edb/server/pgrust/src/auth/mod.rs @@ -0,0 +1,11 @@ +mod md5; +mod scram; +mod stringprep; +mod stringprep_table; + +pub use md5::md5_password; +pub use scram::{ + generate_salted_password, generate_stored_key, ClientEnvironment, ClientTransaction, + SCRAMError, Sha256Out, +}; +pub use stringprep::{sasl_normalize_password, sasl_normalize_password_bytes}; diff --git a/edb/server/pgrust/src/auth/scram.rs b/edb/server/pgrust/src/auth/scram.rs new file mode 100644 index 000000000000..e60ca734ee18 --- /dev/null +++ b/edb/server/pgrust/src/auth/scram.rs @@ -0,0 +1,815 @@ +//! # SCRAM (Salted Challenge Response Authentication Mechanism) +//! +//! # Transaction +//! +//! 
The transaction consists of four steps: +//! +//! 1. **Client's Initial Response**: The client sends its username and initial nonce. +//! 2. **Server's Challenge**: The server responds with a combined nonce, a base64-encoded salt, and an iteration count for the PBKDF2 algorithm. +//! 3. **Client's Proof**: The client sends its proof of possession of the password, along with the combined nonce and base64-encoded channel binding data. +//! 4. **Server's Final Response**: The server sends its verifier, proving successful authentication. +//! +//! This transaction securely authenticates the client to the server without transmitting the actual password. +//! +//! # Parameters +//! +//! The following parameters are used in the SCRAM authentication exchange: +//! +//! * `r=` (nonce): A random string generated by the client and server to ensure the uniqueness of each authentication exchange. +//! The client initially sends its nonce, and the server responds with a combined nonce (client’s nonce + server’s nonce). +//! +//! * `c=` (channel binding): A base64-encoded representation of the channel binding data. +//! This parameter is used to bind the authentication to the specific channel over which it is occurring, ensuring the integrity of the communication channel. +//! +//! * `s=` (salt): A base64-encoded salt provided by the server. +//! The salt is used in conjunction with the client’s password to generate a salted password for enhanced security. +//! +//! * `i=` (iteration count): The number of iterations to apply in the PBKDF2 (Password-Based Key Derivation Function 2) algorithm. +//! This parameter defines the computational cost of generating the salted password. +//! +//! * `n=` (name): The username of the client. +//! This parameter is included in the client’s initial response. +//! +//! * `p=` (proof): The client’s proof of possession of the password. +//! 
This is a base64-encoded value calculated using the salted password and other SCRAM parameters to prove that the client knows the password without sending it directly. +//! +//! * `v=` (verifier): The server’s verifier, which is used to prove that the server also knows the shared secret. +//! This parameter is included in the server’s final message to confirm successful authentication. +#![allow(unused)] + +use base64::{prelude::BASE64_STANDARD, Engine}; +use hmac::{Hmac, Mac}; +use sha2::{digest::FixedOutput, Digest, Sha256}; +use std::borrow::Cow; + +use super::sasl_normalize_password_bytes; + +const CHANNEL_BINDING_ENCODED: &str = "biws"; +const MINIMUM_NONCE_LENGTH: usize = 16; + +type HmacSha256 = Hmac; +pub type Sha256Out = [u8; 32]; + +#[derive(Debug, thiserror::Error)] +pub enum SCRAMError { + #[error("Invalid encoding")] + ProtocolError, +} + +pub trait ServerEnvironment { + fn get_password_parameters(&self, username: &str) -> (Cow<'static, str>, usize); + fn get_salted_password(&self, username: &str) -> Sha256Out; + fn generate_nonce(&self) -> String; +} + +#[derive(Default)] +pub struct ServerTransaction { + state: ServerState, +} + +impl ServerTransaction { + pub fn success(&self) -> bool { + matches!(self.state, ServerState::Success) + } + + pub fn process_message( + &mut self, + message: &[u8], + env: &impl ServerEnvironment, + ) -> Result>, SCRAMError> { + match &self.state { + ServerState::Success => Err(SCRAMError::ProtocolError), + ServerState::Initial => { + let message = ClientFirstMessage::decode(message)?; + if message.channel_binding != ChannelBinding::NotSupported("".into()) { + return Err(SCRAMError::ProtocolError); + } + if message.nonce.len() < MINIMUM_NONCE_LENGTH { + return Err(SCRAMError::ProtocolError); + } + let (salt, iterations) = env.get_password_parameters(&message.username); + let mut nonce = message.nonce.to_string(); + nonce += &env.generate_nonce(); + let response = ServerFirstResponse { + combined_nonce: 
nonce.to_string().into(), + salt, + iterations, + }; + self.state = + ServerState::SentChallenge(message.to_owned_bare(), response.to_owned()); + Ok(Some(response.encode().into_bytes())) + } + ServerState::SentChallenge(first_message, first_response) => { + let message = ClientFinalMessage::decode(message)?; + if message.combined_nonce != first_response.combined_nonce { + return Err(SCRAMError::ProtocolError); + } + if message.channel_binding != CHANNEL_BINDING_ENCODED { + return Err(SCRAMError::ProtocolError); + } + let salted_password = env.get_salted_password(&first_message.username); + let (client_proof, server_verifier) = generate_proof( + first_message.encode().as_bytes(), + first_response.encode().as_bytes(), + message.channel_binding.as_bytes(), + message.combined_nonce.as_bytes(), + &salted_password, + ); + let mut proof = vec![]; + BASE64_STANDARD + .decode_vec(message.proof.as_bytes(), &mut proof) + .map_err(|_| SCRAMError::ProtocolError)?; + if proof != client_proof { + return Err(SCRAMError::ProtocolError); + } + self.state = ServerState::Success; + let verifier = BASE64_STANDARD.encode(server_verifier).into(); + Ok(Some(ServerFinalResponse { verifier }.encode().into_bytes())) + } + } + } +} + +#[derive(Default)] +enum ServerState { + #[default] + Initial, + SentChallenge(ClientFirstMessage<'static>, ServerFirstResponse<'static>), + Success, +} + +pub trait ClientEnvironment { + fn get_salted_password(&self, salt: &[u8], iterations: usize) -> Sha256Out; + fn generate_nonce(&self) -> String; +} + +#[derive(Debug)] +pub struct ClientTransaction { + state: ClientState, +} + +impl ClientTransaction { + pub fn new(username: Cow<'static, str>) -> Self { + Self { + state: ClientState::Initial(username), + } + } + + pub fn success(&self) -> bool { + matches!(self.state, ClientState::Success) + } + + pub fn process_message( + &mut self, + message: &[u8], + env: &impl ClientEnvironment, + ) -> Result>, SCRAMError> { + match &self.state { + ClientState::Success 
=> Err(SCRAMError::ProtocolError), + ClientState::Initial(username) => { + if !message.is_empty() { + return Err(SCRAMError::ProtocolError); + } + let nonce = env.generate_nonce().into(); + let message = ClientFirstMessage { + channel_binding: ChannelBinding::NotSupported("".into()), + username: username.clone(), + nonce, + }; + self.state = ClientState::SentFirst(message.to_owned_bare()); + Ok(Some(message.encode().into_bytes())) + } + ClientState::SentFirst(first_message) => { + let message = ServerFirstResponse::decode(message)?; + // Ensure the client nonce was concatenated with the server's nonce + if !message + .combined_nonce + .starts_with(first_message.nonce.as_ref()) + { + return Err(SCRAMError::ProtocolError); + } + if message.combined_nonce.len() - first_message.nonce.len() < MINIMUM_NONCE_LENGTH { + return Err(SCRAMError::ProtocolError); + } + let mut buffer = [0; 1024]; + let salt = decode_salt(&message.salt, &mut buffer)?; + let salted_password = env.get_salted_password(&salt, message.iterations); + let (client_proof, server_verifier) = generate_proof( + first_message.encode().as_bytes(), + message.encode().as_bytes(), + CHANNEL_BINDING_ENCODED.as_bytes(), + message.combined_nonce.as_bytes(), + &salted_password, + ); + let message = ClientFinalMessage { + channel_binding: CHANNEL_BINDING_ENCODED.into(), + combined_nonce: message.combined_nonce.to_string().into(), + proof: BASE64_STANDARD.encode(client_proof).into(), + }; + self.state = ClientState::ExpectingVerifier(ServerFinalResponse { + verifier: BASE64_STANDARD.encode(server_verifier).into(), + }); + Ok(Some(message.encode().into_bytes())) + } + ClientState::ExpectingVerifier(server_final_response) => { + let message = ServerFinalResponse::decode(message)?; + if message.verifier != server_final_response.verifier { + return Err(SCRAMError::ProtocolError); + } + self.state = ClientState::Success; + Ok(None) + } + } + } +} + +#[derive(Debug)] +enum ClientState { + Initial(Cow<'static, str>), + 
SentFirst(ClientFirstMessage<'static>), + ExpectingVerifier(ServerFinalResponse<'static>), + Success, +} + +trait Encode { + fn encode(&self) -> String; +} + +trait Decode<'a> { + fn decode(buf: &'a [u8]) -> Result + where + Self: Sized + 'a; +} + +fn extract<'a>(input: &'a [u8], prefix: &'static str) -> Result<&'a str, SCRAMError> { + let bytes = input + .strip_prefix(prefix.as_bytes()) + .ok_or(SCRAMError::ProtocolError)?; + std::str::from_utf8(bytes).map_err(|_| SCRAMError::ProtocolError) +} + +fn inext<'a>(it: &mut impl Iterator) -> Result<&'a [u8], SCRAMError> { + it.next().ok_or(SCRAMError::ProtocolError) +} + +fn hmac(s: &[u8]) -> HmacSha256 { + // This is effectively infallible + HmacSha256::new_from_slice(s).expect("HMAC can take key of any size") +} + +#[derive(Debug, Clone, PartialEq, Eq)] +/// `gs2-cbind-flag` from RFC5802. +enum ChannelBinding<'a> { + /// No channel binding + NotSpecified, + /// "n" -> client doesn't support channel binding. + NotSupported(Cow<'a, str>), + /// "y" -> client does support channel binding but thinks the server does + /// not. + Supported(Cow<'a, str>), + /// "p" -> client requires channel binding. The selected channel binding + /// follows "p=". 
+ Required(Cow<'a, str>, Cow<'a, str>), +} + +#[derive(Debug)] +pub struct ClientFirstMessage<'a> { + channel_binding: ChannelBinding<'a>, + username: Cow<'a, str>, + nonce: Cow<'a, str>, +} + +impl ClientFirstMessage<'_> { + /// Get the bare first message + pub fn to_owned_bare(&self) -> ClientFirstMessage<'static> { + ClientFirstMessage { + channel_binding: ChannelBinding::NotSpecified, + username: self.username.to_string().into(), + nonce: self.nonce.to_string().into(), + } + } +} + +impl Encode for ClientFirstMessage<'_> { + fn encode(&self) -> String { + let channel_binding = match self.channel_binding { + ChannelBinding::NotSpecified => "".to_string(), + ChannelBinding::NotSupported(ref s) => format!("n,{},", s), + ChannelBinding::Supported(ref s) => format!("y,{},", s), + ChannelBinding::Required(ref s, ref t) => format!("p={},{},", t, s), + }; + format!("{channel_binding}n={},r={}", self.username, self.nonce) + } +} + +impl<'a> Decode<'a> for ClientFirstMessage<'a> { + fn decode(buf: &'a [u8]) -> Result { + let mut parts = buf.split(|&b| b == b','); + + // Check for channel binding + let mut next = inext(&mut parts)?; + let mut channel_binding = ChannelBinding::NotSpecified; + match (next.len(), next.first()) { + (_, Some(b'p')) => { + // p=(cb-name),(authz-id), + let Some(cb_name) = next.strip_prefix(b"p=") else { + return Err(SCRAMError::ProtocolError); + }; + let cb_name = + std::str::from_utf8(cb_name).map_err(|_| SCRAMError::ProtocolError)?; + let param = inext(&mut parts)?; + channel_binding = ChannelBinding::Required( + Cow::Borrowed( + std::str::from_utf8(param).map_err(|_| SCRAMError::ProtocolError)?, + ), + cb_name.into(), + ); + next = inext(&mut parts)?; + } + (1, Some(b'n')) => { + let param = inext(&mut parts)?; + channel_binding = ChannelBinding::NotSupported(Cow::Borrowed( + std::str::from_utf8(param).map_err(|_| SCRAMError::ProtocolError)?, + )); + next = inext(&mut parts)?; + } + (1, Some(b'y')) => { + let param = inext(&mut parts)?; + 
channel_binding = ChannelBinding::Supported(Cow::Borrowed( + std::str::from_utf8(param).map_err(|_| SCRAMError::ProtocolError)?, + )); + next = inext(&mut parts)?; + } + (_, None) => { + return Err(SCRAMError::ProtocolError); + } + _ => { + // No channel binding specified + } + } + let username = extract(next, "n=")?.into(); + let nonce = extract(inext(&mut parts)?, "r=")?.into(); + Ok(ClientFirstMessage { + channel_binding, + username, + nonce, + }) + } +} + +pub struct ServerFirstResponse<'a> { + combined_nonce: Cow<'a, str>, + salt: Cow<'a, str>, + iterations: usize, +} + +impl ServerFirstResponse<'_> { + pub fn to_owned(&self) -> ServerFirstResponse<'static> { + ServerFirstResponse { + combined_nonce: self.combined_nonce.to_string().into(), + salt: self.salt.to_string().into(), + iterations: self.iterations, + } + } +} + +impl Encode for ServerFirstResponse<'_> { + fn encode(&self) -> String { + format!( + "r={},s={},i={}", + self.combined_nonce, self.salt, self.iterations + ) + } +} + +impl<'a> Decode<'a> for ServerFirstResponse<'a> { + fn decode(buf: &'a [u8]) -> Result { + let mut parts = buf.split(|&b| b == b','); + let combined_nonce = extract(inext(&mut parts)?, "r=")?.into(); + let salt = extract(inext(&mut parts)?, "s=")?.into(); + let iterations = extract(inext(&mut parts)?, "i=")?; + Ok(ServerFirstResponse { + combined_nonce, + salt, + iterations: str::parse(iterations).map_err(|_| SCRAMError::ProtocolError)?, + }) + } +} + +pub struct ClientFinalMessage<'a> { + channel_binding: Cow<'a, str>, + combined_nonce: Cow<'a, str>, + proof: Cow<'a, str>, +} + +impl Encode for ClientFinalMessage<'_> { + fn encode(&self) -> String { + format!( + "c={},r={},p={}", + self.channel_binding, self.combined_nonce, self.proof + ) + } +} + +impl<'a> Decode<'a> for ClientFinalMessage<'a> { + fn decode(buf: &'a [u8]) -> Result { + let mut parts = buf.split(|&b| b == b','); + let channel_binding = extract(inext(&mut parts)?, "c=")?.into(); + let combined_nonce = 
extract(inext(&mut parts)?, "r=")?.into(); + let proof = extract(inext(&mut parts)?, "p=")?.into(); + Ok(ClientFinalMessage { + channel_binding, + combined_nonce, + proof, + }) + } +} + +#[derive(Debug)] +pub struct ServerFinalResponse<'a> { + verifier: Cow<'a, str>, +} + +impl<'a> Encode for ServerFinalResponse<'a> { + fn encode(&self) -> String { + format!("v={}", self.verifier) + } +} + +impl<'a> Decode<'a> for ServerFinalResponse<'a> { + fn decode(buf: &'a [u8]) -> Result { + let mut parts = buf.split(|&b| b == b','); + let verifier = extract(inext(&mut parts)?, "v=")?.into(); + Ok(ServerFinalResponse { verifier }) + } +} + +pub fn decode_salt<'a>(salt: &str, buffer: &'a mut [u8]) -> Result, SCRAMError> { + // The salt needs to be base64 decoded -- full binary must be used + if let Ok(n) = BASE64_STANDARD.decode_slice(salt, buffer) { + Ok(Cow::Borrowed(&buffer[..n])) + } else { + // In the unlikely case the salt is large -- note that we also fall back to this + // path for invalid base64 strings! + let mut buffer = vec![]; + BASE64_STANDARD + .decode_vec(salt, &mut buffer) + .map_err(|_| SCRAMError::ProtocolError)?; + Ok(Cow::Owned(buffer)) + } +} + +/// Given a password in byte form, generates the salted version of the password, +/// applying SASLprep to it beforehand. 
+pub fn generate_salted_password(password: &[u8], salt: &[u8], iterations: usize) -> Sha256Out { + // Save the pre-keyed hmac + let ui_p = hmac(&sasl_normalize_password_bytes(password)); + + // The initial signature is the salt with a terminator of a 32-bit string ending in 1 + let mut ui = ui_p.clone(); + + ui.update(salt); + ui.update(&[0, 0, 0, 1]); + + // Grab the initial digest + let mut last_hash = Default::default(); + ui.finalize_into(&mut last_hash); + let mut u = last_hash; + + // For X number of iterations, recompute the HMAC signature against the password and the latest iteration of the hash, and XOR it with the previous version + for _ in 0..(iterations - 1) { + let mut ui = ui_p.clone(); + ui.update(&last_hash); + ui.finalize_into(&mut last_hash); + for i in 0..u.len() { + u[i] ^= last_hash[i]; + } + } + + u.as_slice().try_into().unwrap() +} + +/// Generate a stored key compatible with PostgreSQL's encoding. +pub fn generate_stored_key(password: &[u8], salt: &[u8], iterations: usize) -> String { + let digest_key = generate_salted_password(password, salt, iterations); + + let mut client_key = hmac(&digest_key) + .chain_update(b"Client Key") + .finalize() + .into_bytes(); + + let stored_key = Sha256::digest(client_key); + + let server_key = hmac(&digest_key) + .chain_update(b"Server Key") + .finalize() + .into_bytes(); + + format!( + "SCRAM-SHA-256${}:{}${}:{}", + iterations, + BASE64_STANDARD.encode(salt), + BASE64_STANDARD.encode(stored_key), + BASE64_STANDARD.encode(server_key) + ) +} + +fn generate_proof( + first_message_bare: &[u8], + server_first_message: &[u8], + channel_binding: &[u8], + server_nonce: &[u8], + salted_password: &[u8], +) -> (Sha256Out, Sha256Out) { + let ui_p = hmac(salted_password); + + let mut ui = ui_p.clone(); + ui.update(b"Server Key"); + let server_key = ui.finalize_fixed(); + + let mut ui = ui_p.clone(); + ui.update(b"Client Key"); + let client_key = ui.finalize_fixed(); + + let mut hash = Sha256::new(); + 
hash.update(client_key); + let stored_key = hash.finalize_fixed(); + + let auth_message = [ + (first_message_bare), + (b","), + (server_first_message), + (b",c="), + (channel_binding), + (b",r="), + (server_nonce), + ]; + + let mut client_signature = hmac(&stored_key); + for chunk in auth_message { + client_signature.update(chunk); + } + + let client_signature = client_signature.finalize_fixed(); + let mut client_signature: Sha256Out = client_signature.as_slice().try_into().unwrap(); + + for i in 0..client_signature.len() { + client_signature[i] ^= client_key[i]; + } + + let mut server_proof = hmac(&server_key); + for chunk in auth_message { + server_proof.update(chunk); + } + let server_proof = server_proof.finalize_fixed().as_slice().try_into().unwrap(); + + (client_signature, server_proof) +} + +#[cfg(test)] +mod tests { + use super::*; + use hex_literal::hex; + use rstest::rstest; + + // Define a set of test parameters + const CLIENT_NONCE: &str = "2XendqvQOa6cl0+Q7Y6UU0gw"; + const SERVER_NONCE: &str = "xWn3mvDeVZwnUtT09vwXoItO"; + const USERNAME: &str = ""; + const PASSWORD: &[u8] = b"secret"; + const SALT: &str = "t5YekvL6lgy4RyPnsiyqsg=="; + const ITERATIONS: usize = 4096; + const CLIENT_PROOF: &[u8] = "p/HmDcOziQQnyF8fbVnJnlvwoLp1kZY4xsI9cCJhzCE=".as_bytes(); + const SERVER_VERIFY: &[u8] = "g/X0codOryF0nCOWh7KkIab23ZFPX99iLzN5Ghn3nNc=".as_bytes(); + + #[rstest] + #[case( + b"1234", + "1234", + 1, + hex!("EBE7E5BA4BF5A4D178D3BADAADD4C49A98C72FCFF4FB357DA7090D584990FCAA") + )] + #[case( + b"1234", + "1234", + 2, + hex!("F9271C334EE6CD7FEE63BBC86FAF951A4ED9E293BDD72AC33663BAE662D31953") + )] + #[case( + b"1234", + "1234", + 4096, + hex!("4FF8D6443278AB43209DF5A1327949AAC99A5AA23921E5C9199626524776F751") + )] + #[case( + b"password", + "480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu", + 4096, + hex!("E118A9AD43C87938659AD736E63F26BA2EBAF079AA351DB44AE29228FB4F7EF0") + )] + #[case( + b"secret", + "480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu", + 4096, + 
hex!("77DFD8E62A4379296C9769F9BA2F77D503C4647DE7919B47D6CF121986981BCC") + )] + #[case( + b"secret", + "t5YekvL6lgy4RyPnsiyqsg==", + 4096, + hex!("9FB413FE9F1D0C8020400A3D49CFBC47FBFB1251CEA9297630BD025DB2B65171") + )] + #[case( + "😀".as_bytes(), + "t5YekvL6lgy4RyPnsiyqsg==", + 4096, + hex!("AF490CE1BEA2DDB585DAF9C3842D1528AB091EF6FAB2A92489870523A98835EE") + )] + fn test_generate_salted_password( + #[case] password: &[u8], + #[case] salt: &str, + #[case] iterations: usize, + #[case] expected_hash: Sha256Out, + ) { + let mut buffer = [0; 128]; + let salt = decode_salt(salt, &mut buffer).unwrap(); + let hash = generate_salted_password(password, &salt, iterations); + assert_eq!(hash, expected_hash); + } + + /// Tests that use real stored keys from postgres to match normalization + /// behaviour. This exercises the saslprep code. + #[rstest] + // ASCII + #[case(b"password", "SCRAM-SHA-256$4096:jZLwuMbICV2L8i9SsfSEYQ==$Qhd2nOIlLW/dtVFERkVjVNdzzrVwPm2l+WHibmPesoc=:P1aH2cUHyPUbIdO06hEiXdwKxQyqBNUijLGFLkTXcHs=")] + // Unicode + #[case("schön".as_bytes(), "SCRAM-SHA-256$4096:uuH6VXsbbeId2AcdL0WmSA==$imMseND/Sg7tL5Tm1ltZJGa6PsdxwysUZ9s1lXPOPdo=:kMp6Rb9yN3zYpvwkuf0/xQZWhIGEa0ryjwnyDfpL3G0=")] + // Unicode normalization -> half-width to full-width + #[case("パスワード".as_bytes(), "SCRAM-SHA-256$4096:oCSGmW9Llo803DWp94yE0A==$TvNA2Hh1IqwCHlhxHhIaTeI7N/mFSx01D3/tb2VGQfw=:RBDsZImb7XoP6Md1j0zhjf7yBz0ocDoxqsPeFtJLyaI=")] + // Chars that normalize to space and nothing + #[case(b"pass\xc2\xa0\xe2\x80\x80word", "SCRAM-SHA-256$4096:ag3Z1WnqEn8dhTvSP7UtYA==$taWe9cZJYK5Y28V9Nw3zy6E9qQKbqKrMRS5DwlDXG04=:Y4n3uwZ4jQyG7nYCde3vtPxO1p0Oxz5ytJT1W+lqM+I=")] + // Invalid control chars + #[case(b"\x01\x02\x03", "SCRAM-SHA-256$4096:XGcYpEn2cwuS+BZXJBaqFg==$mG53wGoI6pAANoAZl7qxYiKPZ6u3CfhCVZK4et3l52A=:X5PUFkC5MVJWmuBTwWQHTFH81xjiyAHrJ9r0anOPXiI=")] + // Prohibited char (ffff) + #[case(b"\xef\xbf\xbf", 
"SCRAM-SHA-256$4096:Tdv5eCJIm+LU9QJBKO96gQ==$YXE4G3HKPwCmwo4FjiFKaiqVGCDTOpVETv+Fe6wWY9Q=:DK7MZ/OgGGgCDh6EfsmmcyFuaAD+T2Zh78sl+QDQFIo=")] + fn test_stored_key(#[case] password: &[u8], #[case] stored_key: &str) { + use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine}; + use hmac::{Hmac, Mac}; + use sha2::{Digest, Sha256}; + + let salt = stored_key + .split(':') + .nth(1) + .unwrap() + .split('$') + .next() + .unwrap(); + let salt = BASE64_STANDARD.decode(salt).unwrap(); + let generated_key = generate_stored_key(password, &salt, 4096); + assert_eq!(generated_key, stored_key); + } + + #[test] + fn test_client_proof() { + let mut buffer = [0; 128]; + let salt = decode_salt(SALT, &mut buffer).unwrap(); + let salted_password = generate_salted_password(PASSWORD, &salt, ITERATIONS); + let (client, server) = generate_proof( + format!("n={USERNAME},r={CLIENT_NONCE}").as_bytes(), + format!("r={CLIENT_NONCE}{SERVER_NONCE},s={SALT},i={ITERATIONS}").as_bytes(), + CHANNEL_BINDING_ENCODED.as_bytes(), + format!("{CLIENT_NONCE}{SERVER_NONCE}").as_bytes(), + &salted_password, + ); + assert_eq!( + &client, + BASE64_STANDARD.decode(CLIENT_PROOF).unwrap().as_slice() + ); + assert_eq!( + &server, + BASE64_STANDARD.decode(SERVER_VERIFY).unwrap().as_slice() + ); + } + + #[test] + fn test_client_first_message() { + let message = ClientFirstMessage::decode(b"n,,n=,r=480I9uIaXEU9oB2RRcenOxN/").unwrap(); + assert_eq!( + message.channel_binding, + ChannelBinding::NotSupported(Cow::Borrowed("")) + ); + assert_eq!(message.username, ""); + assert_eq!(message.nonce, "480I9uIaXEU9oB2RRcenOxN/"); + assert_eq!( + message.encode(), + "n,,n=,r=480I9uIaXEU9oB2RRcenOxN/".to_owned() + ); + } + + #[test] + fn test_client_first_message_required() { + let message = + ClientFirstMessage::decode(b"p=cb-name,,n=,r=480I9uIaXEU9oB2RRcenOxN/").unwrap(); + assert_eq!( + message.channel_binding, + ChannelBinding::Required(Cow::Borrowed(""), Cow::Borrowed("cb-name")) + ); + 
assert_eq!(message.username, ""); + assert_eq!(message.nonce, "480I9uIaXEU9oB2RRcenOxN/"); + assert_eq!( + message.encode(), + "p=cb-name,,n=,r=480I9uIaXEU9oB2RRcenOxN/".to_owned() + ); + } + + #[test] + fn test_server_first_response() { + let message = ServerFirstResponse::decode( + b"r=480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu,s=t5YekvL6lgy4RyPnsiyqsg==,i=4096", + ) + .unwrap(); + assert_eq!( + message.combined_nonce, + "480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu" + ); + assert_eq!(message.salt, "t5YekvL6lgy4RyPnsiyqsg=="); + assert_eq!(message.iterations, 4096); + assert_eq!( + message.encode(), + "r=480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu,s=t5YekvL6lgy4RyPnsiyqsg==,i=4096" + .to_owned() + ); + } + + #[test] + fn test_client_final_message() { + let message = b"c=biws,r=480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu,p=7Vkz4SfWTNhB3hNdhTucC+3MaGmg3+PrAG3xfuepjP4="; + let decoded = ClientFinalMessage::decode(message).unwrap(); + assert_eq!(decoded.channel_binding, "biws"); + assert_eq!( + decoded.combined_nonce, + "480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu" + ); + assert_eq!( + decoded.proof, + "7Vkz4SfWTNhB3hNdhTucC+3MaGmg3+PrAG3xfuepjP4=" + ); + let encoded = decoded.encode(); + assert_eq!(encoded, "c=biws,r=480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu,p=7Vkz4SfWTNhB3hNdhTucC+3MaGmg3+PrAG3xfuepjP4="); + } + + #[test] + fn test_server_final_response() { + let message = b"v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="; + let decoded: ServerFinalResponse = ServerFinalResponse::decode(message).unwrap(); + assert_eq!( + decoded.verifier, + "6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=" + ); + let encoded = decoded.encode(); + assert_eq!(encoded, "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="); + } + + /// Run a SCRAM conversation with a fixed set of parameters + #[test] + fn test_transaction() { + let mut server = ServerTransaction::default(); + let mut client = ClientTransaction::new("username".into()); + + struct Env {} + impl 
ClientEnvironment for Env { + fn generate_nonce(&self) -> String { + "<<>>".into() + } + fn get_salted_password(&self, salt: &[u8], iterations: usize) -> Sha256Out { + generate_salted_password(b"password", salt, iterations) + } + } + impl ServerEnvironment for Env { + fn get_salted_password(&self, username: &str) -> Sha256Out { + assert_eq!(username, "username"); + generate_salted_password(b"password", b"hello", 4096) + } + fn generate_nonce(&self) -> String { + "<<>>".into() + } + fn get_password_parameters(&self, username: &str) -> (Cow<'static, str>, usize) { + assert_eq!(username, "username"); + (Cow::Borrowed("aGVsbG8="), 4096) + } + } + let env = Env {}; + + let message = client.process_message(&[], &env).unwrap().unwrap(); + eprintln!("client: {:?}", String::from_utf8(message.clone()).unwrap()); + let message = server.process_message(&message, &env).unwrap().unwrap(); + eprintln!("server: {:?}", String::from_utf8(message.clone()).unwrap()); + let message = client.process_message(&message, &env).unwrap().unwrap(); + eprintln!("client: {:?}", String::from_utf8(message.clone()).unwrap()); + let message = server.process_message(&message, &env).unwrap().unwrap(); + eprintln!("server: {:?}", String::from_utf8(message.clone()).unwrap()); + assert!(client.process_message(&message, &env).unwrap().is_none()); + assert!(client.success()); + assert!(server.success()); + } +} diff --git a/edb/server/pgrust/src/auth/stringprep.rs b/edb/server/pgrust/src/auth/stringprep.rs new file mode 100644 index 000000000000..67c57d9b227e --- /dev/null +++ b/edb/server/pgrust/src/auth/stringprep.rs @@ -0,0 +1,209 @@ +use core::str; +use roaring::RoaringBitmap; +use std::{ops::Range, sync::OnceLock}; +use unicode_normalization::UnicodeNormalization; + +/// Normalize the password using the SASLprep algorithm from RFC4013. 
+/// +/// # Examples +/// +/// ``` +/// use pgrust::auth::sasl_normalize_password_bytes; +/// +/// assert_eq!(sasl_normalize_password_bytes(b"password").as_ref(), b"password"); +/// assert_eq!(sasl_normalize_password_bytes("passw\u{00A0}rd".as_bytes()).as_ref(), b"passw rd"); +/// assert_eq!(sasl_normalize_password_bytes("pass\u{200B}word".as_bytes()).as_ref(), b"password"); +/// // This test case demonstrates that invalid UTF-8 sequences are returned unchanged. +/// // The bytes 0xFF, 0xFE, 0xFD do not form a valid UTF-8 sequence, so the function +/// // should return them as-is without attempting to normalize or modify them. +/// assert_eq!(sasl_normalize_password_bytes(&[0xFF, 0xFE, 0xFD]).as_ref(), &[0xFF, 0xFE, 0xFD]); +/// ``` +pub fn sasl_normalize_password_bytes(s: &[u8]) -> Cow<[u8]> { + if s.is_ascii() { + Cow::Borrowed(s) + } else if let Ok(s) = str::from_utf8(s) { + match sasl_normalize_password(s) { + Cow::Borrowed(s) => Cow::Borrowed(s.as_bytes()), + Cow::Owned(s) => Cow::Owned(s.into()), + } + } else { + Cow::Borrowed(s) + } +} + +/// Normalize the password using the SASLprep from RFC4013. 
+/// +/// # Examples +/// +/// ``` +/// use pgrust::auth::sasl_normalize_password; +/// +/// assert_eq!(sasl_normalize_password("password").as_ref(), "password"); +/// assert_eq!(sasl_normalize_password("passw\u{00A0}rd").as_ref(), "passw rd"); +/// assert_eq!(sasl_normalize_password("pass\u{200B}word").as_ref(), "password"); +/// assert_eq!(sasl_normalize_password("パスワード").as_ref(), "パスワード"); // precomposed Japanese +/// assert_eq!(sasl_normalize_password("パスワード").as_ref(), "パスワード"); // half-width to full-width katakana +/// assert_eq!(sasl_normalize_password("\u{0061}\u{0308}"), "\u{00E4}"); // a + combining diaeresis -> ä +/// assert_eq!(sasl_normalize_password("\u{00E4}"), "\u{00E4}"); // precomposed ä +/// assert_eq!(sasl_normalize_password("\u{0041}\u{0308}"), "\u{00C4}"); // A + combining diaeresis -> Ä +/// assert_eq!(sasl_normalize_password("\u{00C4}"), "\u{00C4}"); // precomposed Ä +/// assert_eq!(sasl_normalize_password("\u{0627}\u{0644}\u{0639}\u{0631}\u{0628}\u{064A}\u{0629}"), "\u{0627}\u{0644}\u{0639}\u{0631}\u{0628}\u{064A}\u{0629}"); // Arabic (RandALCat) +/// ``` +pub fn sasl_normalize_password(s: &str) -> Cow<str> { + if s.is_ascii() { + return Cow::Borrowed(s); + } + + let mut normalized = String::with_capacity(s.len()); + + // Step 1 of SASLPrep: Map. Per the algorithm, we map non-ascii space + // characters to ASCII spaces (\x20 or \u0020, but we will use ' ') and + // commonly mapped to nothing characters are removed + // Table C.1.2 -- non-ASCII spaces + // Table B.1 -- "Commonly mapped to nothing" + for c in s.chars() { + if !maps_to_nothing::is_char_included(c as u32) { + if maps_to_space::is_char_included(c as u32) { + normalized.push(' '); + } else { + normalized.push(c); + } + } + } + + // If at this point the password is empty, PostgreSQL uses the original + // password + if normalized.is_empty() { + return Cow::Borrowed(s); + } + + // Step 2 of SASLPrep: Normalize. 
Normalize the password using the + // Unicode normalization algorithm to NFKC form + let normalized = normalized.chars().nfkc().collect::<String>(); + + // If the password is empty, PostgreSQL uses the original password + if normalized.is_empty() { + return Cow::Borrowed(s); + } + + // Step 3 of SASLPrep: Prohibited characters. If PostgreSQL detects any + // of the prohibited characters in SASLPrep, it will use the original + // password + // We also include "unassigned code points" in the prohibited character + // category as PostgreSQL does the same + if normalized.chars().any(is_saslprep_prohibited) { + return Cow::Borrowed(s); + } + + // Step 4 of SASLPrep: Bi-directional characters. PostgreSQL follows the + // rules for bi-directional characters laid out in RFC3454 Sec. 6 which + // are: + // 1. Characters in RFC 3454 Sec 5.8 are prohibited (C.8) + // 2. If a string contains a RandALCat character, it cannot contain any + // LCat character + // 3. If the string contains any RandALCat character, a RandALCat + // character must be the first and last character of the string + // RandALCat characters are found in table D.1, whereas LCat are in D.2. + // A RandALCat character is a character with unambiguously right-to-left + // directionality. + let first_char = normalized.chars().next().unwrap(); + let last_char = normalized.chars().last().unwrap(); + + let contains_rand_al_cat = normalized + .chars() + .any(|c| table_d1::is_char_included(c as u32)); + if contains_rand_al_cat { + let contains_l_cat = normalized + .chars() + .any(|c| table_d2::is_char_included(c as u32)); + if !table_d1::is_char_included(first_char as u32) + || !table_d1::is_char_included(last_char as u32) + || contains_l_cat + { + return Cow::Borrowed(s); + } + } + + // return the normalized password + Cow::Owned(normalized) +} + +#[macro_export] +macro_rules! 
__process_ranges { + ( + $name:ident => + $( ($first:literal, $last:literal) )* + ) => { + pub mod $name { + #[allow(unused)] + pub const RANGES: [std::ops::Range<u32>; [$($first),*].len()] = [ + $( + $first..$last, + )* + ]; + + #[allow(non_contiguous_range_endpoints)] + #[allow(unused)] + pub fn is_char_included(c: u32) -> bool { + match c { + $( + $first..$last => true, + )* + _ => false, + } + } + } + }; +} +use std::borrow::Cow; + +pub(crate) use __process_ranges as process_ranges; + +use super::stringprep_table::{maps_to_nothing, maps_to_space, not_prohibited, table_d1, table_d2}; + +fn create_bitmap_from_ranges(ranges: &[Range<u32>]) -> RoaringBitmap { + let mut bitmap = RoaringBitmap::new(); + for range in ranges { + bitmap.insert_range(range.clone()); + } + bitmap +} + +static NOT_PROHIBITED_BITMAP: std::sync::OnceLock<RoaringBitmap> = OnceLock::new(); + +fn get_not_prohibited_bitmap() -> &'static RoaringBitmap { + NOT_PROHIBITED_BITMAP.get_or_init(|| create_bitmap_from_ranges(&not_prohibited::RANGES)) +} + +#[inline(always)] +fn is_saslprep_prohibited(c: char) -> bool { + !get_not_prohibited_bitmap().contains(c as u32) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_prohibited() { + assert!(is_saslprep_prohibited('\0')); + assert!(is_saslprep_prohibited('\u{100000}')); + } + + #[test] + fn generate_roaring_bitmap() { + let bitmap = create_bitmap_from_ranges(&not_prohibited::RANGES); + + // You can save the bitmap to a file or use it in other ways + // For example, to save it to a file: + // use std::fs::File; + // use std::io::BufWriter; + // let file = File::create("saslprep_prohibited.bin").unwrap(); + // let mut writer = BufWriter::new(file); + // bitmap.serialize_into(&mut writer).unwrap(); + + // Print some statistics about the bitmap + println!("Bitmap cardinality: {}", bitmap.len()); + println!("Bitmap size in bytes: {}", bitmap.serialized_size()); + } +} diff --git a/edb/server/pgrust/src/auth/stringprep_table.rs 
b/edb/server/pgrust/src/auth/stringprep_table.rs new file mode 100644 index 000000000000..d489235e2727 --- /dev/null +++ b/edb/server/pgrust/src/auth/stringprep_table.rs @@ -0,0 +1,801 @@ +super::stringprep::process_ranges!(not_prohibited => +(0x20, 0x7f) +(0xa1, 0x221) +(0x222, 0x234) +(0x250, 0x2ae) +(0x2b0, 0x2ef) +(0x300, 0x340) +(0x342, 0x350) +(0x360, 0x370) +(0x374, 0x376) +(0x37a, 0x37b) +(0x37e, 0x37f) +(0x384, 0x38b) +(0x38c, 0x38d) +(0x38e, 0x3a2) +(0x3a3, 0x3cf) +(0x3d0, 0x3f7) +(0x400, 0x487) +(0x488, 0x4cf) +(0x4d0, 0x4f6) +(0x4f8, 0x4fa) +(0x500, 0x510) +(0x531, 0x557) +(0x559, 0x560) +(0x561, 0x588) +(0x589, 0x58b) +(0x591, 0x5a2) +(0x5a3, 0x5ba) +(0x5bb, 0x5c5) +(0x5d0, 0x5eb) +(0x5f0, 0x5f5) +(0x60c, 0x60d) +(0x61b, 0x61c) +(0x61f, 0x620) +(0x621, 0x63b) +(0x640, 0x656) +(0x660, 0x6dd) +(0x6de, 0x6ee) +(0x6f0, 0x6ff) +(0x700, 0x70e) +(0x710, 0x72d) +(0x730, 0x74b) +(0x780, 0x7b2) +(0x901, 0x904) +(0x905, 0x93a) +(0x93c, 0x94e) +(0x950, 0x955) +(0x958, 0x971) +(0x981, 0x984) +(0x985, 0x98d) +(0x98f, 0x991) +(0x993, 0x9a9) +(0x9aa, 0x9b1) +(0x9b2, 0x9b3) +(0x9b6, 0x9ba) +(0x9bc, 0x9bd) +(0x9be, 0x9c5) +(0x9c7, 0x9c9) +(0x9cb, 0x9ce) +(0x9d7, 0x9d8) +(0x9dc, 0x9de) +(0x9df, 0x9e4) +(0x9e6, 0x9fb) +(0xa02, 0xa03) +(0xa05, 0xa0b) +(0xa0f, 0xa11) +(0xa13, 0xa29) +(0xa2a, 0xa31) +(0xa32, 0xa34) +(0xa35, 0xa37) +(0xa38, 0xa3a) +(0xa3c, 0xa3d) +(0xa3e, 0xa43) +(0xa47, 0xa49) +(0xa4b, 0xa4e) +(0xa59, 0xa5d) +(0xa5e, 0xa5f) +(0xa66, 0xa75) +(0xa81, 0xa84) +(0xa85, 0xa8c) +(0xa8d, 0xa8e) +(0xa8f, 0xa92) +(0xa93, 0xaa9) +(0xaaa, 0xab1) +(0xab2, 0xab4) +(0xab5, 0xaba) +(0xabc, 0xac6) +(0xac7, 0xaca) +(0xacb, 0xace) +(0xad0, 0xad1) +(0xae0, 0xae1) +(0xae6, 0xaf0) +(0xb01, 0xb04) +(0xb05, 0xb0d) +(0xb0f, 0xb11) +(0xb13, 0xb29) +(0xb2a, 0xb31) +(0xb32, 0xb34) +(0xb36, 0xb3a) +(0xb3c, 0xb44) +(0xb47, 0xb49) +(0xb4b, 0xb4e) +(0xb56, 0xb58) +(0xb5c, 0xb5e) +(0xb5f, 0xb62) +(0xb66, 0xb71) +(0xb82, 0xb84) +(0xb85, 0xb8b) +(0xb8e, 0xb91) +(0xb92, 0xb96) +(0xb99, 0xb9b) 
+(0xb9c, 0xb9d) +(0xb9e, 0xba0) +(0xba3, 0xba5) +(0xba8, 0xbab) +(0xbae, 0xbb6) +(0xbb7, 0xbba) +(0xbbe, 0xbc3) +(0xbc6, 0xbc9) +(0xbca, 0xbce) +(0xbd7, 0xbd8) +(0xbe7, 0xbf3) +(0xc01, 0xc04) +(0xc05, 0xc0d) +(0xc0e, 0xc11) +(0xc12, 0xc29) +(0xc2a, 0xc34) +(0xc35, 0xc3a) +(0xc3e, 0xc45) +(0xc46, 0xc49) +(0xc4a, 0xc4e) +(0xc55, 0xc57) +(0xc60, 0xc62) +(0xc66, 0xc70) +(0xc82, 0xc84) +(0xc85, 0xc8d) +(0xc8e, 0xc91) +(0xc92, 0xca9) +(0xcaa, 0xcb4) +(0xcb5, 0xcba) +(0xcbe, 0xcc5) +(0xcc6, 0xcc9) +(0xcca, 0xcce) +(0xcd5, 0xcd7) +(0xcde, 0xcdf) +(0xce0, 0xce2) +(0xce6, 0xcf0) +(0xd02, 0xd04) +(0xd05, 0xd0d) +(0xd0e, 0xd11) +(0xd12, 0xd29) +(0xd2a, 0xd3a) +(0xd3e, 0xd44) +(0xd46, 0xd49) +(0xd4a, 0xd4e) +(0xd57, 0xd58) +(0xd60, 0xd62) +(0xd66, 0xd70) +(0xd82, 0xd84) +(0xd85, 0xd97) +(0xd9a, 0xdb2) +(0xdb3, 0xdbc) +(0xdbd, 0xdbe) +(0xdc0, 0xdc7) +(0xdca, 0xdcb) +(0xdcf, 0xdd5) +(0xdd6, 0xdd7) +(0xdd8, 0xde0) +(0xdf2, 0xdf5) +(0xe01, 0xe3b) +(0xe3f, 0xe5c) +(0xe81, 0xe83) +(0xe84, 0xe85) +(0xe87, 0xe89) +(0xe8a, 0xe8b) +(0xe8d, 0xe8e) +(0xe94, 0xe98) +(0xe99, 0xea0) +(0xea1, 0xea4) +(0xea5, 0xea6) +(0xea7, 0xea8) +(0xeaa, 0xeac) +(0xead, 0xeba) +(0xebb, 0xebe) +(0xec0, 0xec5) +(0xec6, 0xec7) +(0xec8, 0xece) +(0xed0, 0xeda) +(0xedc, 0xede) +(0xf00, 0xf48) +(0xf49, 0xf6b) +(0xf71, 0xf8c) +(0xf90, 0xf98) +(0xf99, 0xfbd) +(0xfbe, 0xfcd) +(0xfcf, 0xfd0) +(0x1000, 0x1022) +(0x1023, 0x1028) +(0x1029, 0x102b) +(0x102c, 0x1033) +(0x1036, 0x103a) +(0x1040, 0x105a) +(0x10a0, 0x10c6) +(0x10d0, 0x10f9) +(0x10fb, 0x10fc) +(0x1100, 0x115a) +(0x115f, 0x11a3) +(0x11a8, 0x11fa) +(0x1200, 0x1207) +(0x1208, 0x1247) +(0x1248, 0x1249) +(0x124a, 0x124e) +(0x1250, 0x1257) +(0x1258, 0x1259) +(0x125a, 0x125e) +(0x1260, 0x1287) +(0x1288, 0x1289) +(0x128a, 0x128e) +(0x1290, 0x12af) +(0x12b0, 0x12b1) +(0x12b2, 0x12b6) +(0x12b8, 0x12bf) +(0x12c0, 0x12c1) +(0x12c2, 0x12c6) +(0x12c8, 0x12cf) +(0x12d0, 0x12d7) +(0x12d8, 0x12ef) +(0x12f0, 0x130f) +(0x1310, 0x1311) +(0x1312, 0x1316) +(0x1318, 0x131f) +(0x1320, 
0x1347) +(0x1348, 0x135b) +(0x1361, 0x137d) +(0x13a0, 0x13f5) +(0x1401, 0x1677) +(0x1681, 0x169d) +(0x16a0, 0x16f1) +(0x1700, 0x170d) +(0x170e, 0x1715) +(0x1720, 0x1737) +(0x1740, 0x1754) +(0x1760, 0x176d) +(0x176e, 0x1771) +(0x1772, 0x1774) +(0x1780, 0x17dd) +(0x17e0, 0x17ea) +(0x1800, 0x180e) +(0x1810, 0x181a) +(0x1820, 0x1878) +(0x1880, 0x18aa) +(0x1e00, 0x1e9c) +(0x1ea0, 0x1efa) +(0x1f00, 0x1f16) +(0x1f18, 0x1f1e) +(0x1f20, 0x1f46) +(0x1f48, 0x1f4e) +(0x1f50, 0x1f58) +(0x1f59, 0x1f5a) +(0x1f5b, 0x1f5c) +(0x1f5d, 0x1f5e) +(0x1f5f, 0x1f7e) +(0x1f80, 0x1fb5) +(0x1fb6, 0x1fc5) +(0x1fc6, 0x1fd4) +(0x1fd6, 0x1fdc) +(0x1fdd, 0x1ff0) +(0x1ff2, 0x1ff5) +(0x1ff6, 0x1fff) +(0x2010, 0x2028) +(0x2030, 0x2053) +(0x2057, 0x2058) +(0x2070, 0x2072) +(0x2074, 0x208f) +(0x20a0, 0x20b2) +(0x20d0, 0x20eb) +(0x2100, 0x213b) +(0x213d, 0x214c) +(0x2153, 0x2184) +(0x2190, 0x23cf) +(0x2400, 0x2427) +(0x2440, 0x244b) +(0x2460, 0x24ff) +(0x2500, 0x2614) +(0x2616, 0x2618) +(0x2619, 0x267e) +(0x2680, 0x268a) +(0x2701, 0x2705) +(0x2706, 0x270a) +(0x270c, 0x2728) +(0x2729, 0x274c) +(0x274d, 0x274e) +(0x274f, 0x2753) +(0x2756, 0x2757) +(0x2758, 0x275f) +(0x2761, 0x2795) +(0x2798, 0x27b0) +(0x27b1, 0x27bf) +(0x27d0, 0x27ec) +(0x27f0, 0x2b00) +(0x2e80, 0x2e9a) +(0x2e9b, 0x2ef4) +(0x2f00, 0x2fd6) +(0x3001, 0x3040) +(0x3041, 0x3097) +(0x3099, 0x3100) +(0x3105, 0x312d) +(0x3131, 0x318f) +(0x3190, 0x31b8) +(0x31f0, 0x321d) +(0x3220, 0x3244) +(0x3251, 0x327c) +(0x327f, 0x32cc) +(0x32d0, 0x32ff) +(0x3300, 0x3377) +(0x337b, 0x33de) +(0x33e0, 0x33ff) +(0x3400, 0x4db6) +(0x4e00, 0x9fa6) +(0xa000, 0xa48d) +(0xa490, 0xa4c7) +(0xac00, 0xd7a4) +(0xf900, 0xfa2e) +(0xfa30, 0xfa6b) +(0xfb00, 0xfb07) +(0xfb13, 0xfb18) +(0xfb1d, 0xfb37) +(0xfb38, 0xfb3d) +(0xfb3e, 0xfb3f) +(0xfb40, 0xfb42) +(0xfb43, 0xfb45) +(0xfb46, 0xfbb2) +(0xfbd3, 0xfd40) +(0xfd50, 0xfd90) +(0xfd92, 0xfdc8) +(0xfdf0, 0xfdfd) +(0xfe00, 0xfe10) +(0xfe20, 0xfe24) +(0xfe30, 0xfe47) +(0xfe49, 0xfe53) +(0xfe54, 0xfe67) +(0xfe68, 0xfe6c) +(0xfe70, 
0xfe75) +(0xfe76, 0xfefd) +(0xff01, 0xffbf) +(0xffc2, 0xffc8) +(0xffca, 0xffd0) +(0xffd2, 0xffd8) +(0xffda, 0xffdd) +(0xffe0, 0xffe7) +(0xffe8, 0xffef) +(0x10300, 0x1031f) +(0x10320, 0x10324) +(0x10330, 0x1034b) +(0x10400, 0x10426) +(0x10428, 0x1044e) +(0x1d000, 0x1d0f6) +(0x1d100, 0x1d127) +(0x1d12a, 0x1d173) +(0x1d17b, 0x1d1de) +(0x1d400, 0x1d455) +(0x1d456, 0x1d49d) +(0x1d49e, 0x1d4a0) +(0x1d4a2, 0x1d4a3) +(0x1d4a5, 0x1d4a7) +(0x1d4a9, 0x1d4ad) +(0x1d4ae, 0x1d4ba) +(0x1d4bb, 0x1d4bc) +(0x1d4bd, 0x1d4c1) +(0x1d4c2, 0x1d4c4) +(0x1d4c5, 0x1d506) +(0x1d507, 0x1d50b) +(0x1d50d, 0x1d515) +(0x1d516, 0x1d51d) +(0x1d51e, 0x1d53a) +(0x1d53b, 0x1d53f) +(0x1d540, 0x1d545) +(0x1d546, 0x1d547) +(0x1d54a, 0x1d551) +(0x1d552, 0x1d6a4) +(0x1d6a8, 0x1d7ca) +(0x1d7ce, 0x1d800) +(0x20000, 0x2a6d7) +(0x2f800, 0x2fa1e) +); +super::stringprep::process_ranges!(maps_to_space => +(0xa0, 0xa1) +(0x1680, 0x1681) +(0x2000, 0x200c) +(0x202f, 0x2030) +(0x205f, 0x2060) +(0x3000, 0x3001) +); +super::stringprep::process_ranges!(maps_to_nothing => +(0xad, 0xae) +(0x34f, 0x350) +(0x1806, 0x1807) +(0x180b, 0x180e) +(0x200b, 0x200e) +(0x2060, 0x2061) +(0xfe00, 0xfe10) +(0xfeff, 0xff00) +); +super::stringprep::process_ranges!(table_d1 => +(0x5be, 0x5bf) +(0x5c0, 0x5c1) +(0x5c3, 0x5c4) +(0x5d0, 0x5eb) +(0x5f0, 0x5f5) +(0x61b, 0x61c) +(0x61f, 0x620) +(0x621, 0x63b) +(0x640, 0x64b) +(0x66d, 0x670) +(0x671, 0x6d6) +(0x6dd, 0x6de) +(0x6e5, 0x6e7) +(0x6fa, 0x6ff) +(0x700, 0x70e) +(0x710, 0x711) +(0x712, 0x72d) +(0x780, 0x7a6) +(0x7b1, 0x7b2) +(0x200f, 0x2010) +(0xfb1d, 0xfb1e) +(0xfb1f, 0xfb29) +(0xfb2a, 0xfb37) +(0xfb38, 0xfb3d) +(0xfb3e, 0xfb3f) +(0xfb40, 0xfb42) +(0xfb43, 0xfb45) +(0xfb46, 0xfbb2) +(0xfbd3, 0xfd3e) +(0xfd50, 0xfd90) +(0xfd92, 0xfdc8) +(0xfdf0, 0xfdfd) +(0xfe70, 0xfe75) +(0xfe76, 0xfefd) +); +super::stringprep::process_ranges!(table_d2 => +(0x41, 0x5b) +(0x61, 0x7b) +(0xaa, 0xab) +(0xb5, 0xb6) +(0xba, 0xbb) +(0xc0, 0xd7) +(0xd8, 0xf7) +(0xf8, 0x221) +(0x222, 0x234) +(0x250, 0x2ae) 
+(0x2b0, 0x2b9) +(0x2bb, 0x2c2) +(0x2d0, 0x2d2) +(0x2e0, 0x2e5) +(0x2ee, 0x2ef) +(0x37a, 0x37b) +(0x386, 0x387) +(0x388, 0x38b) +(0x38c, 0x38d) +(0x38e, 0x3a2) +(0x3a3, 0x3cf) +(0x3d0, 0x3f6) +(0x400, 0x483) +(0x48a, 0x4cf) +(0x4d0, 0x4f6) +(0x4f8, 0x4fa) +(0x500, 0x510) +(0x531, 0x557) +(0x559, 0x560) +(0x561, 0x588) +(0x589, 0x58a) +(0x903, 0x904) +(0x905, 0x93a) +(0x93d, 0x941) +(0x949, 0x94d) +(0x950, 0x951) +(0x958, 0x962) +(0x964, 0x971) +(0x982, 0x984) +(0x985, 0x98d) +(0x98f, 0x991) +(0x993, 0x9a9) +(0x9aa, 0x9b1) +(0x9b2, 0x9b3) +(0x9b6, 0x9ba) +(0x9be, 0x9c1) +(0x9c7, 0x9c9) +(0x9cb, 0x9cd) +(0x9d7, 0x9d8) +(0x9dc, 0x9de) +(0x9df, 0x9e2) +(0x9e6, 0x9f2) +(0x9f4, 0x9fb) +(0xa05, 0xa0b) +(0xa0f, 0xa11) +(0xa13, 0xa29) +(0xa2a, 0xa31) +(0xa32, 0xa34) +(0xa35, 0xa37) +(0xa38, 0xa3a) +(0xa3e, 0xa41) +(0xa59, 0xa5d) +(0xa5e, 0xa5f) +(0xa66, 0xa70) +(0xa72, 0xa75) +(0xa83, 0xa84) +(0xa85, 0xa8c) +(0xa8d, 0xa8e) +(0xa8f, 0xa92) +(0xa93, 0xaa9) +(0xaaa, 0xab1) +(0xab2, 0xab4) +(0xab5, 0xaba) +(0xabd, 0xac1) +(0xac9, 0xaca) +(0xacb, 0xacd) +(0xad0, 0xad1) +(0xae0, 0xae1) +(0xae6, 0xaf0) +(0xb02, 0xb04) +(0xb05, 0xb0d) +(0xb0f, 0xb11) +(0xb13, 0xb29) +(0xb2a, 0xb31) +(0xb32, 0xb34) +(0xb36, 0xb3a) +(0xb3d, 0xb3f) +(0xb40, 0xb41) +(0xb47, 0xb49) +(0xb4b, 0xb4d) +(0xb57, 0xb58) +(0xb5c, 0xb5e) +(0xb5f, 0xb62) +(0xb66, 0xb71) +(0xb83, 0xb84) +(0xb85, 0xb8b) +(0xb8e, 0xb91) +(0xb92, 0xb96) +(0xb99, 0xb9b) +(0xb9c, 0xb9d) +(0xb9e, 0xba0) +(0xba3, 0xba5) +(0xba8, 0xbab) +(0xbae, 0xbb6) +(0xbb7, 0xbba) +(0xbbe, 0xbc0) +(0xbc1, 0xbc3) +(0xbc6, 0xbc9) +(0xbca, 0xbcd) +(0xbd7, 0xbd8) +(0xbe7, 0xbf3) +(0xc01, 0xc04) +(0xc05, 0xc0d) +(0xc0e, 0xc11) +(0xc12, 0xc29) +(0xc2a, 0xc34) +(0xc35, 0xc3a) +(0xc41, 0xc45) +(0xc60, 0xc62) +(0xc66, 0xc70) +(0xc82, 0xc84) +(0xc85, 0xc8d) +(0xc8e, 0xc91) +(0xc92, 0xca9) +(0xcaa, 0xcb4) +(0xcb5, 0xcba) +(0xcbe, 0xcbf) +(0xcc0, 0xcc5) +(0xcc7, 0xcc9) +(0xcca, 0xccc) +(0xcd5, 0xcd7) +(0xcde, 0xcdf) +(0xce0, 0xce2) +(0xce6, 0xcf0) +(0xd02, 0xd04) 
+(0xd05, 0xd0d) +(0xd0e, 0xd11) +(0xd12, 0xd29) +(0xd2a, 0xd3a) +(0xd3e, 0xd41) +(0xd46, 0xd49) +(0xd4a, 0xd4d) +(0xd57, 0xd58) +(0xd60, 0xd62) +(0xd66, 0xd70) +(0xd82, 0xd84) +(0xd85, 0xd97) +(0xd9a, 0xdb2) +(0xdb3, 0xdbc) +(0xdbd, 0xdbe) +(0xdc0, 0xdc7) +(0xdcf, 0xdd2) +(0xdd8, 0xde0) +(0xdf2, 0xdf5) +(0xe01, 0xe31) +(0xe32, 0xe34) +(0xe40, 0xe47) +(0xe4f, 0xe5c) +(0xe81, 0xe83) +(0xe84, 0xe85) +(0xe87, 0xe89) +(0xe8a, 0xe8b) +(0xe8d, 0xe8e) +(0xe94, 0xe98) +(0xe99, 0xea0) +(0xea1, 0xea4) +(0xea5, 0xea6) +(0xea7, 0xea8) +(0xeaa, 0xeac) +(0xead, 0xeb1) +(0xeb2, 0xeb4) +(0xebd, 0xebe) +(0xec0, 0xec5) +(0xec6, 0xec7) +(0xed0, 0xeda) +(0xedc, 0xede) +(0xf00, 0xf18) +(0xf1a, 0xf35) +(0xf36, 0xf37) +(0xf38, 0xf39) +(0xf3e, 0xf48) +(0xf49, 0xf6b) +(0xf7f, 0xf80) +(0xf85, 0xf86) +(0xf88, 0xf8c) +(0xfbe, 0xfc6) +(0xfc7, 0xfcd) +(0xfcf, 0xfd0) +(0x1000, 0x1022) +(0x1023, 0x1028) +(0x1029, 0x102b) +(0x102c, 0x102d) +(0x1031, 0x1032) +(0x1038, 0x1039) +(0x1040, 0x1058) +(0x10a0, 0x10c6) +(0x10d0, 0x10f9) +(0x10fb, 0x10fc) +(0x1100, 0x115a) +(0x115f, 0x11a3) +(0x11a8, 0x11fa) +(0x1200, 0x1207) +(0x1208, 0x1247) +(0x1248, 0x1249) +(0x124a, 0x124e) +(0x1250, 0x1257) +(0x1258, 0x1259) +(0x125a, 0x125e) +(0x1260, 0x1287) +(0x1288, 0x1289) +(0x128a, 0x128e) +(0x1290, 0x12af) +(0x12b0, 0x12b1) +(0x12b2, 0x12b6) +(0x12b8, 0x12bf) +(0x12c0, 0x12c1) +(0x12c2, 0x12c6) +(0x12c8, 0x12cf) +(0x12d0, 0x12d7) +(0x12d8, 0x12ef) +(0x12f0, 0x130f) +(0x1310, 0x1311) +(0x1312, 0x1316) +(0x1318, 0x131f) +(0x1320, 0x1347) +(0x1348, 0x135b) +(0x1361, 0x137d) +(0x13a0, 0x13f5) +(0x1401, 0x1677) +(0x1681, 0x169b) +(0x16a0, 0x16f1) +(0x1700, 0x170d) +(0x170e, 0x1712) +(0x1720, 0x1732) +(0x1735, 0x1737) +(0x1740, 0x1752) +(0x1760, 0x176d) +(0x176e, 0x1771) +(0x1780, 0x17b7) +(0x17be, 0x17c6) +(0x17c7, 0x17c9) +(0x17d4, 0x17db) +(0x17dc, 0x17dd) +(0x17e0, 0x17ea) +(0x1810, 0x181a) +(0x1820, 0x1878) +(0x1880, 0x18a9) +(0x1e00, 0x1e9c) +(0x1ea0, 0x1efa) +(0x1f00, 0x1f16) +(0x1f18, 0x1f1e) +(0x1f20, 0x1f46) 
+(0x1f48, 0x1f4e) +(0x1f50, 0x1f58) +(0x1f59, 0x1f5a) +(0x1f5b, 0x1f5c) +(0x1f5d, 0x1f5e) +(0x1f5f, 0x1f7e) +(0x1f80, 0x1fb5) +(0x1fb6, 0x1fbd) +(0x1fbe, 0x1fbf) +(0x1fc2, 0x1fc5) +(0x1fc6, 0x1fcd) +(0x1fd0, 0x1fd4) +(0x1fd6, 0x1fdc) +(0x1fe0, 0x1fed) +(0x1ff2, 0x1ff5) +(0x1ff6, 0x1ffd) +(0x200e, 0x200f) +(0x2071, 0x2072) +(0x207f, 0x2080) +(0x2102, 0x2103) +(0x2107, 0x2108) +(0x210a, 0x2114) +(0x2115, 0x2116) +(0x2119, 0x211e) +(0x2124, 0x2125) +(0x2126, 0x2127) +(0x2128, 0x2129) +(0x212a, 0x212e) +(0x212f, 0x2132) +(0x2133, 0x213a) +(0x213d, 0x2140) +(0x2145, 0x214a) +(0x2160, 0x2184) +(0x2336, 0x237b) +(0x2395, 0x2396) +(0x249c, 0x24ea) +(0x3005, 0x3008) +(0x3021, 0x302a) +(0x3031, 0x3036) +(0x3038, 0x303d) +(0x3041, 0x3097) +(0x309d, 0x30a0) +(0x30a1, 0x30fb) +(0x30fc, 0x3100) +(0x3105, 0x312d) +(0x3131, 0x318f) +(0x3190, 0x31b8) +(0x31f0, 0x321d) +(0x3220, 0x3244) +(0x3260, 0x327c) +(0x327f, 0x32b1) +(0x32c0, 0x32cc) +(0x32d0, 0x32ff) +(0x3300, 0x3377) +(0x337b, 0x33de) +(0x33e0, 0x33ff) +(0x3400, 0x4db6) +(0x4e00, 0x9fa6) +(0xa000, 0xa48d) +(0xac00, 0xd7a4) +(0xd800, 0xfa2e) +(0xfa30, 0xfa6b) +(0xfb00, 0xfb07) +(0xfb13, 0xfb18) +(0xff21, 0xff3b) +(0xff41, 0xff5b) +(0xff66, 0xffbf) +(0xffc2, 0xffc8) +(0xffca, 0xffd0) +(0xffd2, 0xffd8) +(0xffda, 0xffdd) +(0x10300, 0x1031f) +(0x10320, 0x10324) +(0x10330, 0x1034b) +(0x10400, 0x10426) +(0x10428, 0x1044e) +(0x1d000, 0x1d0f6) +(0x1d100, 0x1d127) +(0x1d12a, 0x1d167) +(0x1d16a, 0x1d173) +(0x1d183, 0x1d185) +(0x1d18c, 0x1d1aa) +(0x1d1ae, 0x1d1de) +(0x1d400, 0x1d455) +(0x1d456, 0x1d49d) +(0x1d49e, 0x1d4a0) +(0x1d4a2, 0x1d4a3) +(0x1d4a5, 0x1d4a7) +(0x1d4a9, 0x1d4ad) +(0x1d4ae, 0x1d4ba) +(0x1d4bb, 0x1d4bc) +(0x1d4bd, 0x1d4c1) +(0x1d4c2, 0x1d4c4) +(0x1d4c5, 0x1d506) +(0x1d507, 0x1d50b) +(0x1d50d, 0x1d515) +(0x1d516, 0x1d51d) +(0x1d51e, 0x1d53a) +(0x1d53b, 0x1d53f) +(0x1d540, 0x1d545) +(0x1d546, 0x1d547) +(0x1d54a, 0x1d551) +(0x1d552, 0x1d6a4) +(0x1d6a8, 0x1d7ca) +(0x20000, 0x2a6d7) +(0x2f800, 0x2fa1e) +(0xf0000, 0xffffe) 
+(0x100000, 0x10fffe) +); diff --git a/edb/server/pgrust/src/auth/stringprep_table_prep.py b/edb/server/pgrust/src/auth/stringprep_table_prep.py new file mode 100644 index 000000000000..db125be2a5cb --- /dev/null +++ b/edb/server/pgrust/src/auth/stringprep_table_prep.py @@ -0,0 +1,43 @@ +import stringprep + +SASLPREP_PROHIBITED = ( + stringprep.in_table_a1, # PostgreSQL treats this as prohibited + stringprep.in_table_c12, + stringprep.in_table_c21_c22, + stringprep.in_table_c3, + stringprep.in_table_c4, + stringprep.in_table_c5, + stringprep.in_table_c6, + stringprep.in_table_c7, + stringprep.in_table_c8, + stringprep.in_table_c9, +) + + +def gen(name, f): + r = None + print(f"super::stringprep::process_ranges!({name} =>") + for c in range(0, 0x110000): + c = chr(c) + prohibited = f(c) + if prohibited and r is None: + r = ord(c) + if not prohibited and r is not None: + print(f"(0x{r:x}, 0x{ord(c):x})") + r = None + if r: + print(f"0x{r:x}") + print(");") + + +gen("not_prohibited", lambda c: not any( + in_prohibited_table(c) + for in_prohibited_table in SASLPREP_PROHIBITED + )) + +gen("maps_to_space", lambda c: stringprep.in_table_c12(c)) + +gen("maps_to_nothing", lambda c: stringprep.in_table_b1(c)) + +gen("table_d1", lambda c: stringprep.in_table_d1(c)) +gen("table_d2", lambda c: stringprep.in_table_d2(c)) diff --git a/edb/server/pgrust/src/conn_string.rs b/edb/server/pgrust/src/conn_string.rs deleted file mode 100644 index b749fbfced27..000000000000 --- a/edb/server/pgrust/src/conn_string.rs +++ /dev/null @@ -1,978 +0,0 @@ -use itertools::Itertools; -use percent_encoding::percent_decode_str; -use serde_derive::Serialize; -use std::borrow::Cow; -use std::collections::HashMap; -use std::fs::OpenOptions; -use std::io::ErrorKind; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use std::path::{Path, PathBuf}; -use std::str::FromStr; -use std::time::Duration; -use thiserror::Error; -use url::Url; - -#[derive(Error, Debug, PartialEq, Eq)] 
-#[allow(clippy::enum_variant_names)] -pub enum ParseError { - #[error( - "Invalid DSN: scheme is expected to be either \"postgresql\" or \"postgres\", got {0}" - )] - InvalidScheme(String), - - #[error("Invalid value for parameter \"{0}\": \"{1}\"")] - InvalidParameter(String, String), - - #[error("Invalid port: \"{0}\"")] - InvalidPort(String), - - #[error("Unexpected number of ports, must be either a single port or the same number as the host count: \"{0}\"")] - InvalidPortCount(String), - - #[error("Invalid hostname: \"{0}\"")] - InvalidHostname(String), - - #[error("Could not determine the connection {0}")] - MissingRequiredParameter(String), - - #[error("URL parse error: {0}")] - UrlParseError(#[from] url::ParseError), -} - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] -pub enum Host { - Hostname(String, u16), - IP(IpAddr, u16, Option), - Path(String, u16), - Abstract(String, u16), -} - -#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] -pub enum Password { - /// The password is unspecified and should be read from the user's default - /// passfile if it exists. - #[default] - Unspecified, - /// The password was specified. - Specified(String), - /// The passfile is specified. 
- Passfile(PathBuf), -} - -#[derive(Serialize)] -pub enum PasswordWarning { - NotFile(PathBuf), - NotExists(PathBuf), - NotAccessible(PathBuf), - Permissions(PathBuf, u32), -} - -impl std::fmt::Display for PasswordWarning { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - PasswordWarning::NotFile(path) => write!(f, "Password file {path:?} is not a plain file"), - PasswordWarning::NotExists(path) => write!(f, "Password file {path:?} does not exist"), - PasswordWarning::NotAccessible(path) => write!(f, "Password file {path:?} is not accessible"), - PasswordWarning::Permissions(path, mode) => write!(f, "Password file {path:?} has group or world access ({mode:o}); permissions should be u=rw (0600) or less"), - } - } -} - -#[cfg(windows)] -const PGPASSFILE: &str = "pgpass.conf"; -#[cfg(not(windows))] -const PGPASSFILE: &str = ".pgpass"; - -impl Password { - pub fn password(&self) -> Option<&str> { - match self { - Password::Specified(password) => Some(password), - _ => None, - } - } - - /// Attempt to resolve a password against the given homedir. 
- pub fn resolve( - &mut self, - home: &Path, - hosts: &[Host], - database: &str, - user: &str, - ) -> Result, std::io::Error> { - let passfile = match self { - Password::Unspecified => { - let passfile = home.join(PGPASSFILE); - // Don't warn about implicit missing or inaccessible files - if !matches!(passfile.try_exists(), Ok(true)) { - *self = Password::Unspecified; - return Ok(None); - } - if !passfile.is_file() { - *self = Password::Unspecified; - return Ok(None); - } - passfile - } - Password::Specified(_) => return Ok(None), - Password::Passfile(passfile) => { - let passfile = passfile.clone(); - if matches!(passfile.try_exists(), Ok(false)) { - *self = Password::Unspecified; - return Ok(Some(PasswordWarning::NotExists(passfile))); - } - if passfile.exists() && !passfile.is_file() { - *self = Password::Unspecified; - return Ok(Some(PasswordWarning::NotFile(passfile))); - } - passfile - } - }; - - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - - let metadata = match passfile.metadata() { - Err(err) if err.kind() == ErrorKind::PermissionDenied => { - *self = Password::Unspecified; - return Ok(Some(PasswordWarning::NotAccessible(passfile))); - } - res => res?, - }; - let permissions = metadata.permissions(); - let mode = permissions.mode(); - - if mode & (0o070) != 0 { - *self = Password::Unspecified; - return Ok(Some(PasswordWarning::Permissions(passfile, mode))); - } - } - - let file = match OpenOptions::new().read(true).open(&passfile) { - Err(err) if err.kind() == ErrorKind::PermissionDenied => { - *self = Password::Unspecified; - return Ok(Some(PasswordWarning::NotAccessible(passfile))); - } - res => res?, - }; - if let Some(password) = read_password_file( - hosts, - database, - user, - std::io::read_to_string(file)?.split('\n'), - ) { - *self = Password::Specified(password); - } else { - *self = Password::Unspecified; - } - Ok(None) - } -} - -#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize)] -pub struct ConnectionParameters { - pub 
hosts: Vec, - pub database: String, - pub user: String, - pub password: Password, - pub connect_timeout: Option, - pub server_settings: HashMap, - pub ssl: Ssl, -} - -#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize)] -#[allow(clippy::large_enum_variant)] -pub enum Ssl { - #[default] - Disable, - Enable(SslMode, SslParameters), -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize)] -pub enum SslMode { - #[serde(rename = "allow")] - Allow, - #[serde(rename = "prefer")] - Prefer, - #[serde(rename = "require")] - Require, - #[serde(rename = "verify_ca")] - VerifyCA, - #[serde(rename = "verify_full")] - VerifyFull, -} - -#[derive(Default, Clone, Debug, PartialEq, Eq, Serialize)] -pub struct SslParameters { - pub cert: Option, - pub key: Option, - pub password: Option, - pub rootcert: Option, - pub crl: Option, - pub min_protocol_version: Option, - pub max_protocol_version: Option, - pub keylog_filename: Option, - pub verify_crl_check_chain: Option, -} - -#[derive(Default, Debug, Serialize)] -pub struct SslPaths { - pub rootcert: Option, - pub crl: Option, - pub key: Option, - pub cert: Option, -} - -impl Ssl { - /// Resolve the SSL paths relative to the home directory. - pub fn resolve(&self, home_dir: &Path) -> Result { - let postgres_dir = home_dir; - let Ssl::Enable(mode, params) = self else { - return Ok(SslPaths::default()); - }; - let mut paths = SslPaths::default(); - if *mode >= SslMode::Require { - let root_cert = params - .rootcert - .clone() - .unwrap_or_else(|| postgres_dir.join("root.crt")); - if root_cert.exists() { - paths.rootcert = Some(root_cert); - } else if *mode > SslMode::Require { - return Err(std::io::Error::new(ErrorKind::NotFound, - format!("Root certificate not found: {root_cert:?}. 
Either provide the file or change sslmode to disable SSL certificate verification."))); - } - - let crl = params - .crl - .clone() - .unwrap_or_else(|| postgres_dir.join("root.crl")); - if crl.exists() { - paths.crl = Some(crl); - } - } - let key = params - .key - .clone() - .unwrap_or_else(|| postgres_dir.join("postgresql.key")); - if key.exists() { - paths.key = Some(key); - } - let cert = params - .cert - .clone() - .unwrap_or_else(|| postgres_dir.join("postgresql.crt")); - if cert.exists() { - paths.cert = Some(cert); - } - Ok(paths) - } -} - -pub trait EnvVar { - fn read(&self, name: &'static str) -> Option>; -} - -impl EnvVar for HashMap -where - K: std::hash::Hash + Eq + std::borrow::Borrow, - V: std::borrow::Borrow, -{ - fn read(&self, name: &'static str) -> Option> { - self.get(name).map(|value| value.borrow().into()) - } -} - -impl EnvVar for std::env::Vars { - fn read(&self, name: &'static str) -> Option> { - if let Ok(value) = std::env::var(name) { - Some(value.into()) - } else { - None - } - } -} - -impl EnvVar for &[(&str, &str)] { - fn read(&self, name: &'static str) -> Option> { - for (key, value) in self.iter() { - if *key == name { - return Some((*value).into()); - } - } - None - } -} - -impl EnvVar for () { - fn read(&self, _: &'static str) -> Option> { - None - } -} - -fn maybe_decode(str: Cow) -> Cow { - if str.contains('%') { - if let Ok(str) = percent_decode_str(&str).decode_utf8() { - str.into_owned().into() - } else { - str.into_owned().into() - } - } else { - str - } -} - -fn parse_port(port: &str) -> Result { - if port.contains('%') { - let decoded = percent_decode_str(port) - .decode_utf8() - .map_err(|_| ParseError::InvalidPort(port.to_string()))?; - decoded - .parse::() - .map_err(|_| ParseError::InvalidPort(port.to_string())) - } else { - port.parse::() - .map_err(|_| ParseError::InvalidPort(port.to_string())) - } -} - -fn parse_hostlist( - hostspecs: &[&str], - mut specified_ports: &[u16], -) -> Result, ParseError> { - let mut hosts 
= vec![]; - - if specified_ports.is_empty() { - specified_ports = &[5432]; - } else if specified_ports.len() != hostspecs.len() && specified_ports.len() > 1 { - return Err(ParseError::InvalidPortCount(format!("{specified_ports:?}"))); - } - - for (i, hostspec) in hostspecs.iter().enumerate() { - let port = specified_ports[i % specified_ports.len()]; - - let host = if hostspec.starts_with('/') { - Host::Path(hostspec.to_string(), port) - } else if hostspec.starts_with('[') { - // Handling IPv6 address - let end_bracket = hostspec - .find(']') - .ok_or_else(|| ParseError::InvalidHostname(hostspec.to_string()))?; - - // Extract interface (optional) after % - let (interface, ipv6_part, port_part) = if let Some(pos) = hostspec.find('%') { - ( - Some(hostspec[pos + 1..end_bracket].to_string()), - &hostspec[1..pos], - &hostspec[end_bracket + 1..], - ) - } else { - ( - None, - &hostspec[1..end_bracket], - &hostspec[end_bracket + 1..], - ) - }; - let addr = Ipv6Addr::from_str(ipv6_part) - .map_err(|_| ParseError::InvalidHostname(hostspec.to_string()))?; - - let port = if let Some(stripped) = port_part.strip_prefix(':') { - parse_port(stripped)? - } else { - port - }; - Host::IP(IpAddr::V6(addr), port, interface) - } else { - let parts: Vec<&str> = hostspec.split(':').collect(); - let addr = parts[0].to_string(); - let port = if parts.len() > 1 { - parse_port(parts[1])? - } else { - port - }; - - if let Ok(ip) = Ipv4Addr::from_str(&addr) { - Host::IP(IpAddr::V4(ip), port, None) - } else { - Host::Hostname(addr, port) - } - }; - - hosts.push(host) - } - Ok(hosts) -} - -pub fn parse_postgres_url( - url_str: &str, - env: impl EnvVar, -) -> Result { - let url_str = if let Some(url) = url_str.strip_prefix("postgres://") { - url - } else if let Some(url) = url_str.strip_prefix("postgresql://") { - url - } else { - return Err(ParseError::InvalidScheme( - url_str.split(':').next().unwrap_or_default().to_owned(), - )); - }; - - let path_or_query = url_str.find(|c| c == '?' 
|| c == '/'); - let (authority, path_and_query) = match path_or_query { - Some(index) => url_str.split_at(index), - None => (url_str, ""), - }; - - let (auth, host) = match authority.split_once('@') { - Some((auth, host)) => (auth, host), - None => ("", authority), - }; - - let url = Url::parse(&format!("unused://{auth}@host{path_and_query}"))?; - - let mut server_settings = HashMap::new(); - let mut host: Option> = if host.is_empty() { - None - } else { - Some(host.into()) - } - .map(maybe_decode); - let mut port = None; - - let mut user: Option> = match url.username() { - "" => None, - user => { - let decoded = percent_decode_str(user); - if let Ok(user) = decoded.decode_utf8() { - Some(user) - } else { - Some(user.into()) - } - } - }; - let mut password: Option> = url.password().map(|p| p.into()).map(maybe_decode); - let mut database: Option> = match url.path() { - "" | "/" => None, - path => Some(path.trim_start_matches('/').into()), - } - .map(maybe_decode); - - let mut passfile = None; - let mut connect_timeout = None; - - let mut sslmode = None; - let mut sslcert = None; - let mut sslkey = None; - let mut sslpassword = None; - let mut sslrootcert = None; - let mut sslcrl = None; - let mut ssl_min_protocol_version = None; - let mut ssl_max_protocol_version = None; - - for (name, value) in url.query_pairs() { - match name.as_ref() { - "host" => { - if host.is_none() { - host = Some(value); - } - } - "port" => { - if port.is_none() { - port = Some( - value - .split(',') - .map(parse_port) - .collect::, _>>()?, - ); - } - } - "dbname" | "database" => { - if database.is_none() { - database = Some(value); - } - } - "user" => { - if user.is_none() { - user = Some(value); - } - } - "password" => { - if password.is_none() { - password = Some(value); - } - } - "passfile" => passfile = Some(value), - "connect_timeout" => connect_timeout = Some(value), - - "sslmode" => sslmode = Some(value), - "sslcert" => sslcert = Some(value), - "sslkey" => sslkey = Some(value), - 
"sslpassword" => sslpassword = Some(value), - "sslrootcert" => sslrootcert = Some(value), - "sslcrl" => sslcrl = Some(value), - "ssl_min_protocol_version" => ssl_min_protocol_version = Some(value), - "ssl_max_protocol_version" => ssl_max_protocol_version = Some(value), - - name => { - server_settings.insert(name.to_string(), value.to_string()); - } - }; - } - - if host.is_none() { - host = env.read("PGHOST"); - } - if port.is_none() { - if let Some(value) = env.read("PGPORT") { - port = Some( - value - .split(',') - .map(parse_port) - .collect::, _>>()?, - ); - } - } - - if host.is_none() { - host = Some("/run/postgresql,/var/run/postgresql,/tmp,/private/tmp,localhost".into()); - } - let host = host - .as_ref() - .map(|s| s.split(',').collect_vec()) - .unwrap_or_default(); - let hosts = parse_hostlist(&host, port.as_deref().unwrap_or_default())?; - - if hosts.is_empty() { - return Err(ParseError::MissingRequiredParameter("address".to_string())); - } - - if user.is_none() { - user = env.read("PGUSER"); - } - if password.is_none() { - password = env.read("PGPASSWORD"); - } - if database.is_none() { - database = env.read("PGDATABASE"); - } - if database.is_none() { - database = user.clone(); - } - - let Some(user) = user else { - return Err(ParseError::MissingRequiredParameter("user".to_string())); - }; - let Some(database) = database else { - return Err(ParseError::MissingRequiredParameter("database".to_string())); - }; - - let password = match password { - Some(p) => Password::Specified(p.into_owned()), - None => { - if let Some(passfile) = passfile.or_else(|| env.read("PGPASSFILE")) { - Password::Passfile(passfile.into_owned().into()) - } else { - Password::Unspecified - } - } - }; - - if connect_timeout.is_none() { - connect_timeout = env.read("PGCONNECT_TIMEOUT"); - } - - // Match the same behavior of libpq - // https://www.postgresql.org/docs/current/libpq-connect.html - let connect_timeout = match connect_timeout { - None => None, - Some(s) => { - let seconds = 
s.parse::().map_err(|_| { - ParseError::InvalidParameter("connect_timeout".to_string(), s.to_string()) - })?; - if seconds <= 0 { - None - } else { - Some(Duration::from_secs(seconds.max(2) as _)) - } - } - }; - - let any_tcp = hosts - .iter() - .any(|host| matches!(host, Host::Hostname(..) | Host::IP(..))); - - if sslmode.is_none() { - sslmode = env.read("PGSSLMODE"); - } - - if sslmode.is_none() && any_tcp { - sslmode = Some("prefer".into()); - } - - let ssl = if let Some(sslmode) = sslmode { - if sslmode == "disable" { - Ssl::Disable - } else { - let sslmode = match sslmode.as_ref() { - "allow" => SslMode::Allow, - "prefer" => SslMode::Prefer, - "require" => SslMode::Require, - "verify_ca" | "verify-ca" => SslMode::VerifyCA, - "verify_full" | "verify-full" => SslMode::VerifyFull, - _ => { - return Err(ParseError::InvalidParameter( - "sslmode".to_string(), - sslmode.to_string(), - )) - } - }; - let mut ssl = SslParameters::default(); - if sslmode >= SslMode::Require { - if sslrootcert.is_none() { - sslrootcert = env.read("PGSSLROOTCERT"); - } - ssl.rootcert = sslrootcert.map(|s| PathBuf::from(s.into_owned())); - if sslcrl.is_none() { - sslcrl = env.read("PGSSLCRL"); - } - ssl.crl = sslcrl.map(|s| PathBuf::from(s.into_owned())); - } - if sslkey.is_none() { - sslkey = env.read("PGSSLKEY"); - } - ssl.key = sslkey.map(|s| PathBuf::from(s.into_owned())); - if sslcert.is_none() { - sslcert = env.read("PGSSLCERT"); - } - ssl.cert = sslcert.map(|s| PathBuf::from(s.into_owned())); - if ssl_min_protocol_version.is_none() { - ssl_min_protocol_version = env.read("PGSSLMINPROTOCOLVERSION"); - } - ssl.min_protocol_version = ssl_min_protocol_version.map(|s| s.into_owned()); - if ssl_max_protocol_version.is_none() { - ssl_max_protocol_version = env.read("PGSSLMAXPROTOCOLVERSION"); - } - ssl.max_protocol_version = ssl_max_protocol_version.map(|s| s.into_owned()); - - // There is no environment variable equivalent to this option - ssl.password = sslpassword.map(|s| 
s.into_owned()); - - Ssl::Enable(sslmode, ssl) - } - } else { - Ssl::Disable - }; - - Ok(ConnectionParameters { - hosts, - database: database.into_owned(), - user: user.into_owned(), - password, - connect_timeout, - server_settings, - ssl, - }) -} - -fn read_password_file( - hosts: &[Host], - database: &str, - user: &str, - reader: impl Iterator>, -) -> Option { - for line in reader { - let line = line.as_ref().trim(); - - if line.is_empty() || line.starts_with('#') { - continue; - } - - let mut parts = vec![String::new()]; - let mut backslash = false; - for c in line.chars() { - if backslash { - parts.last_mut().unwrap().push(c); - backslash = false; - continue; - } - if c == '\\' { - backslash = true; - continue; - } - if c == ':' && parts.len() <= 4 { - parts.push(String::new()); - continue; - } - parts.last_mut().unwrap().push(c); - } - - if parts.len() == 5 { - for host in hosts { - let port = match host { - Host::Hostname(hostname, port) => { - if parts[0] != "*" && parts[0] != hostname.as_str() { - continue; - } - *port - } - Host::IP(hostname, port, _) => { - if parts[0] != "*" && str::parse(&parts[0]) != Ok(*hostname) { - continue; - } - *port - } - Host::Path(_, port) | Host::Abstract(_, port) => { - if parts[0] != "*" && parts[0] != "localhost" { - continue; - } - *port - } - }; - if parts[1] != "*" && str::parse(&parts[1]) != Ok(port) { - continue; - } - if parts[2] != "*" && parts[2] != database { - continue; - } - if parts[3] != "*" && parts[3] != user { - continue; - } - return Some(parts.pop().unwrap()); - } - } - } - - None -} - -#[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - - #[test] - fn test_parse_hostlist() { - assert_eq!( - parse_hostlist(&["hostname"], &[1234]), - Ok(vec![Host::Hostname("hostname".to_string(), 1234)]) - ); - assert_eq!( - parse_hostlist(&["hostname:4321"], &[1234]), - Ok(vec![Host::Hostname("hostname".to_string(), 4321)]) - ); - assert_eq!( - parse_hostlist(&["/path"], &[1234]), - 
Ok(vec![Host::Path("/path".to_string(), 1234)]) - ); - assert_eq!( - parse_hostlist(&["[2001:db8::1234]", "[::1]"], &[1234]), - Ok(vec![ - Host::IP( - IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x1234)), - 1234, - None - ), - Host::IP(IpAddr::V6(Ipv6Addr::LOCALHOST), 1234, None), - ]) - ); - assert_eq!( - parse_hostlist(&["[2001:db8::1234%eth0]"], &[1234]), - Ok(vec![Host::IP( - IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x1234)), - 1234, - Some("eth0".to_string()) - ),]) - ); - } - - #[test] - fn test_parse_password_file() { - let input = r#" -abc:*:*:user:password from pgpass for user@abc -localhost:*:*:*:password from pgpass for localhost -cde:5433:*:*:password from pgpass for cde:5433 - -*:*:*:testuser:password from pgpass for testuser -*:*:testdb:*:password from pgpass for testdb -# comment -*:*:test\:db:test\\:password from pgpass with escapes - "# - .trim(); - - for (host, database, user, output) in [ - ( - Host::Hostname("abc".to_owned(), 1234), - "database", - "user", - Some("password from pgpass for user@abc"), - ), - ( - Host::Hostname("localhost".to_owned(), 1234), - "database", - "user", - Some("password from pgpass for localhost"), - ), - ( - Host::Path("/tmp".into(), 1234), - "database", - "user", - Some("password from pgpass for localhost"), - ), - ( - Host::Hostname("hmm".to_owned(), 1234), - "database", - "testuser", - Some("password from pgpass for testuser"), - ), - ( - Host::Hostname("hostname".to_owned(), 1234), - "test:db", - r#"test\"#, - Some("password from pgpass with escapes"), - ), - ( - Host::Hostname("doesntexist".to_owned(), 1234), - "db", - "user", - None, - ), - ] { - assert_eq!( - read_password_file(&[host], database, user, input.split('\n')), - output.map(|s| s.to_owned()) - ); - } - } - - #[test] - fn test_parse_dsn() { - assert_eq!( - parse_postgres_url( - "postgres://", - [ - ("PGUSER", "user"), - ("PGDATABASE", "testdb"), - ("PGPASSWORD", "passw"), - ("PGHOST", "host"), - ("PGPORT", "123"), - 
("PGCONNECT_TIMEOUT", "8"), - ] - .as_slice() - ) - .unwrap(), - ConnectionParameters { - hosts: vec![Host::Hostname("host".to_string(), 123,),], - database: "testdb".to_string(), - user: "user".to_string(), - password: Password::Specified("passw".to_string(),), - connect_timeout: Some(Duration::from_secs(8)), - ssl: Ssl::Enable(SslMode::Prefer, Default::default()), - ..Default::default() - } - ); - - assert_eq!( - parse_postgres_url("postgres://user:pass@host:1234/database", ()).unwrap(), - ConnectionParameters { - hosts: vec![Host::Hostname("host".to_string(), 1234,),], - database: "database".to_string(), - user: "user".to_string(), - password: Password::Specified("pass".to_string(),), - ssl: Ssl::Enable(SslMode::Prefer, Default::default()), - ..Default::default() - } - ); - - assert_eq!( - parse_postgres_url("postgresql://user@host1:1111,host2:2222/db", ()).unwrap(), - ConnectionParameters { - hosts: vec![ - Host::Hostname("host1".to_string(), 1111,), - Host::Hostname("host2".to_string(), 2222,), - ], - database: "db".to_string(), - user: "user".to_string(), - password: Password::Unspecified, - ssl: Ssl::Enable(SslMode::Prefer, Default::default()), - ..Default::default() - } - ); - } - - #[test] - fn test_dsn_with_slashes() { - assert_eq!( - parse_postgres_url( - r#"postgres://test\\@fgh/test\:db?passfile=/tmp/tmpkrjuaje4"#, - () - ) - .unwrap(), - ConnectionParameters { - hosts: vec![Host::Hostname("fgh".to_string(), 5432,),], - database: r#"test\:db"#.to_string(), - user: r#"test\\"#.to_string(), - password: Password::Passfile("/tmp/tmpkrjuaje4".to_string().into(),), - ssl: Ssl::Enable(SslMode::Prefer, Default::default()), - ..Default::default() - } - ); - } - - #[test] - fn test_dns_with_params() { - assert_eq!(parse_postgres_url("postgresql://me:ask@127.0.0.1:888/db?param=sss¶m=123&host=testhost&user=testuser&port=2222&database=testdb&sslmode=verify_full&aa=bb", ()).unwrap(), ConnectionParameters { - hosts: vec![ - Host::IP( - 
IpAddr::V4(Ipv4Addr::LOCALHOST), - 888, - None, - ), - ], - database: "db".to_string(), - user: "me".to_string(), - password: Password::Specified( - "ask".to_string(), - ), - server_settings: HashMap::from_iter([ - ("aa".to_string(), "bb".to_string()), - ("param".to_string(), "123".to_string()) - ]), - ssl: Ssl::Enable(SslMode::VerifyFull, Default::default()), - ..Default::default() - }) - } - - #[test] - fn test_dsn_with_escapes() { - assert_eq!( - parse_postgres_url("postgresql://us%40r:p%40ss@h%40st1,h%40st2:543%33/d%62", ()) - .unwrap(), - ConnectionParameters { - hosts: vec![ - Host::Hostname("h@st1".to_string(), 5432,), - Host::Hostname("h@st2".to_string(), 5433,), - ], - database: "db".to_string(), - user: "us@r".to_string(), - password: Password::Specified("p@ss".to_string(),), - ssl: Ssl::Enable(SslMode::Prefer, Default::default()), - ..Default::default() - } - ); - } - - #[test] - fn test_dsn_no_slash() { - assert_eq!( - parse_postgres_url("postgres://user@?port=56226&host=%2Ftmp", ()).unwrap(), - ConnectionParameters { - hosts: vec![Host::Path("/tmp".to_string(), 56226,),], - database: "user".to_string(), - user: "user".to_string(), - password: Password::Unspecified, - ssl: Ssl::Disable, - ..Default::default() - } - ); - } -} diff --git a/edb/server/pgrust/src/connection/conn.rs b/edb/server/pgrust/src/connection/conn.rs new file mode 100644 index 000000000000..7c1f777a6a55 --- /dev/null +++ b/edb/server/pgrust/src/connection/conn.rs @@ -0,0 +1,406 @@ +use super::{ + connect_raw_ssl, + raw_conn::RawClient, + stream::{Stream, StreamWithUpgrade}, + ConnectionSslRequirement, Credentials, +}; +use crate::{ + connection::ConnectionError, + protocol::{ + builder, match_message, meta, CommandComplete, DataRow, ErrorResponse, Message, + ReadyForQuery, RowDescription, StructBuffer, + }, +}; +use futures::FutureExt; +use std::{ + cell::RefCell, + pin::Pin, + sync::Arc, + task::{ready, Poll}, +}; +use std::{ + collections::VecDeque, + future::{poll_fn, Future}, + 
rc::Rc, +}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tracing::{error, trace, warn}; + +#[derive(Debug, thiserror::Error)] +pub enum PGError { + #[error("Invalid state")] + InvalidState, + #[error("Connection failed: {0}")] + Connection(#[from] ConnectionError), + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + #[error("Connection was closed")] + Closed, +} + +pub struct Client +where + (B, C): StreamWithUpgrade, +{ + conn: Rc>, +} + +impl Client +where + (B, C): StreamWithUpgrade, + B: 'static, + C: 'static, +{ + pub fn new( + credentials: Credentials, + socket: B, + config: C, + ) -> (Self, impl Future>) { + let conn = Rc::new(PGConn::new_connection(async move { + let ssl_mode = ConnectionSslRequirement::Optional; + let raw = connect_raw_ssl(credentials, ssl_mode, config, socket).await?; + Ok(raw) + })); + let task = conn.clone().task(); + (Self { conn }, task) + } + + /// Create a new PostgreSQL client and a background task. + pub fn new_raw(stm: RawClient) -> (Self, impl Future>) { + let conn = Rc::new(PGConn::new_raw(stm)); + let task = conn.clone().task(); + (Self { conn }, task) + } + + pub async fn ready(&self) -> Result<(), PGError> { + self.conn.ready().await + } + + pub fn query( + &self, + query: &str, + f: impl QuerySink + 'static, + ) -> impl Future> { + self.conn.clone().query(query.to_owned(), f) + } +} + +struct ErasedQuerySink(Q); + +impl QuerySink for ErasedQuerySink +where + Q: QuerySink, + S: DataSink + 'static, +{ + type Output = Box; + fn error(&self, error: ErrorResponse) { + self.0.error(error) + } + fn rows(&self, rows: RowDescription) -> Self::Output { + Box::new(self.0.rows(rows)) + } +} + +pub trait QuerySink { + type Output: DataSink; + fn rows(&self, rows: RowDescription) -> Self::Output; + fn error(&self, error: ErrorResponse); +} + +impl QuerySink for Box +where + Q: QuerySink + 'static, + S: DataSink + 'static, +{ + type Output = Box; + fn rows(&self, rows: RowDescription) -> Self::Output { + 
Box::new(self.as_ref().rows(rows)) + } + fn error(&self, error: ErrorResponse) { + self.as_ref().error(error) + } +} + +impl QuerySink for (F1, F2) +where + F1: for<'a> Fn(RowDescription) -> S, + F2: for<'a> Fn(ErrorResponse), + S: DataSink, +{ + type Output = S; + fn rows(&self, rows: RowDescription) -> S { + (self.0)(rows) + } + fn error(&self, error: ErrorResponse) { + (self.1)(error) + } +} + +pub trait DataSink { + fn row(&self, values: Result); +} + +impl DataSink for () { + fn row(&self, _: Result) {} +} + +impl DataSink for F +where + F: for<'a> Fn(Result, ErrorResponse<'a>>), +{ + fn row(&self, values: Result) { + (self)(values) + } +} + +impl DataSink for Box { + fn row(&self, values: Result) { + self.as_ref().row(values) + } +} + +struct QueryWaiter { + #[allow(unused)] + tx: tokio::sync::mpsc::UnboundedSender<()>, + f: Box>>, + data: RefCell>>, +} + +#[derive(derive_more::Debug)] +enum ConnState +where + (B, C): StreamWithUpgrade, +{ + #[debug("Connecting(..)")] + #[allow(clippy::type_complexity)] + Connecting(Pin, ConnectionError>>>>), + #[debug("Ready(..)")] + Ready(RawClient, VecDeque), + Error(PGError), + Closed, +} + +struct PGConn +where + (B, C): StreamWithUpgrade, +{ + state: RefCell>, + write_lock: tokio::sync::Mutex<()>, + ready_lock: Arc>, +} + +impl PGConn +where + (B, C): StreamWithUpgrade, +{ + pub fn new_connection( + future: impl Future, ConnectionError>> + 'static, + ) -> Self { + Self { + state: ConnState::Connecting(future.boxed_local()).into(), + write_lock: Default::default(), + ready_lock: Default::default(), + } + } + + pub fn new_raw(stm: RawClient) -> Self { + Self { + state: ConnState::Ready(stm, Default::default()).into(), + write_lock: Default::default(), + ready_lock: Default::default(), + } + } + + fn check_error(&self) -> Result<(), PGError> { + let state = &mut *self.state.borrow_mut(); + match state { + ConnState::Error(..) 
=> { + let ConnState::Error(e) = std::mem::replace(state, ConnState::Closed) else { + unreachable!(); + }; + error!("Connection failed: {e:?}"); + Err(e) + } + ConnState::Closed => Err(PGError::Closed), + _ => Ok(()), + } + } + + #[inline(always)] + async fn ready(&self) -> Result<(), PGError> { + let _ = self.ready_lock.lock().await; + self.check_error() + } + + fn with_stream(&self, f: F) -> Result + where + F: FnOnce(Pin<&mut RawClient>) -> T, + { + match &mut *self.state.borrow_mut() { + ConnState::Ready(ref mut raw_client, _) => Ok(f(Pin::new(raw_client))), + _ => Err(PGError::InvalidState), + } + } + + async fn write(&self, mut buf: &[u8]) -> Result<(), PGError> { + let _lock = self.write_lock.lock().await; + + if buf.is_empty() { + return Ok(()); + } + println!("Write:"); + hexdump::hexdump(buf); + loop { + let n = poll_fn(|cx| { + self.with_stream(|stm| { + let n = match ready!(stm.poll_write(cx, buf)) { + Ok(n) => n, + Err(e) => return Poll::Ready(Err(PGError::Io(e))), + }; + Poll::Ready(Ok(n)) + })? + }) + .await?; + if n == buf.len() { + break; + } + buf = &buf[n..]; + } + Ok(()) + } + + fn process_message(&self, message: Option) -> Result<(), PGError> { + let state = &mut *self.state.borrow_mut(); + match state { + ConnState::Ready(_, queue) => { + let message = message.ok_or(PGError::InvalidState); + match_message!(message?, Backend { + (RowDescription as row) => { + if let Some(qw) = queue.back() { + let qs = qw.f.rows(row); + *qw.data.borrow_mut() = Some(qs); + } + }, + (DataRow as row) => { + if let Some(qw) = queue.back() { + if let Some(qs) = &*qw.data.borrow() { + qs.row(Ok(row)) + } + } + }, + (CommandComplete) => { + if let Some(qw) = queue.back() { + *qw.data.borrow_mut() = None; + } + }, + (ReadyForQuery) => { + queue.pop_front(); + }, + (ErrorResponse as err) => { + if let Some(qw) = queue.back() { + qw.f.error(err); + } + }, + unknown => { + eprintln!("Unknown message: {unknown:?}"); + } + }); + } + ConnState::Connecting(..) 
=> { + return Err(PGError::InvalidState); + } + ConnState::Error(..) | ConnState::Closed => self.check_error()?, + } + + Ok(()) + } + + pub fn task(self: Rc) -> impl Future> { + let ready_lock = self.ready_lock.clone().try_lock_owned().unwrap(); + + async move { + poll_fn(|cx| { + let mut state = self.state.borrow_mut(); + match &mut *state { + ConnState::Connecting(fut) => match fut.poll_unpin(cx) { + Poll::Ready(result) => { + let raw = match result { + Ok(raw) => raw, + Err(e) => { + let error = PGError::Connection(e); + *state = ConnState::Error(error); + return Poll::Ready(Ok::<_, PGError>(())); + } + }; + *state = ConnState::Ready(raw, VecDeque::new()); + Poll::Ready(Ok::<_, PGError>(())) + } + Poll::Pending => Poll::Pending, + }, + ConnState::Ready(..) => Poll::Ready(Ok(())), + ConnState::Error(..) | ConnState::Closed => Poll::Ready(self.check_error()), + } + }) + .await?; + + drop(ready_lock); + + let mut buffer = StructBuffer::::default(); + loop { + let mut read_buffer = [0; 1024]; + let n = poll_fn(|cx| { + self.with_stream(|stm| { + let mut buf = ReadBuf::new(&mut read_buffer); + let res = ready!(stm.poll_read(cx, &mut buf)); + Poll::Ready(res.map(|_| buf.filled().len())).map_err(PGError::Io) + })? 
+ }) + .await?; + + println!("Read:"); + hexdump::hexdump(&read_buffer[..n]); + + buffer.push_fallible(&read_buffer[..n], |message| { + self.process_message(Some(message)) + })?; + + if n == 0 { + break; + } + } + Ok(()) + } + } + + pub async fn query( + self: Rc, + query: String, + f: impl QuerySink + 'static, + ) -> Result<(), PGError> { + trace!("Query task started: {query}"); + let mut rx = match &mut *self.state.borrow_mut() { + ConnState::Ready(_, queue) => { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let f = Box::new(ErasedQuerySink(f)) as _; + queue.push_back(QueryWaiter { + tx, + f, + data: None.into(), + }); + rx + } + x => { + warn!("Connection state was not ready: {x:?}"); + return Err(PGError::InvalidState); + } + }; + + let message = builder::Query { query: &query }.to_vec(); + self.write(&message).await?; + rx.recv().await; + Ok(()) + } +} + +#[cfg(test)] +mod tests {} diff --git a/edb/server/pgrust/src/connection/dsn.rs b/edb/server/pgrust/src/connection/dsn.rs new file mode 100644 index 000000000000..6420c48f1b73 --- /dev/null +++ b/edb/server/pgrust/src/connection/dsn.rs @@ -0,0 +1,817 @@ +//! Parses DSNs for database connections. There are some small differences with +//! how `libpq` works: +//! +//! - Unrecognized options are supported and collected in a map. +//! - `database` is recognized as an alias for `dbname` +//! - `[host1,host2]` is considered valid for psql +use super::params::*; +use super::raw_params::{Host, HostType, RawConnectionParameters}; +use percent_encoding::{percent_decode_str, utf8_percent_encode}; +use std::borrow::Cow; +use std::collections::HashMap; +use std::fs::OpenOptions; +use std::io::ErrorKind; +use std::net::{IpAddr, Ipv4Addr}; +use std::path::Path; +use std::str::FromStr; +use url::Url; + +#[cfg(windows)] +const PGPASSFILE: &str = "pgpass.conf"; +#[cfg(not(windows))] +const PGPASSFILE: &str = ".pgpass"; + +/// Aggressively encode parameters. 
+const ENCODING: &percent_encoding::AsciiSet = &percent_encoding::NON_ALPHANUMERIC.remove(b'_'); + +impl Password { + /// Attempt to resolve a password against the given homedir. + pub fn resolve( + &mut self, + home: &Path, + hosts: &[Host], + database: &str, + user: &str, + ) -> Result, std::io::Error> { + let passfile = match self { + Password::Unspecified => { + let passfile = home.join(PGPASSFILE); + // Don't warn about implicit missing or inaccessible files + if !matches!(passfile.try_exists(), Ok(true)) { + *self = Password::Unspecified; + return Ok(None); + } + if !passfile.is_file() { + *self = Password::Unspecified; + return Ok(None); + } + passfile + } + Password::Specified(_) => return Ok(None), + Password::Passfile(passfile) => { + let passfile = passfile.clone(); + if matches!(passfile.try_exists(), Ok(false)) { + *self = Password::Unspecified; + return Ok(Some(PasswordWarning::NotExists(passfile))); + } + if passfile.exists() && !passfile.is_file() { + *self = Password::Unspecified; + return Ok(Some(PasswordWarning::NotFile(passfile))); + } + passfile + } + }; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + + let metadata = match passfile.metadata() { + Err(err) if err.kind() == ErrorKind::PermissionDenied => { + *self = Password::Unspecified; + return Ok(Some(PasswordWarning::NotAccessible(passfile))); + } + res => res?, + }; + let permissions = metadata.permissions(); + let mode = permissions.mode(); + + if mode & (0o070) != 0 { + *self = Password::Unspecified; + return Ok(Some(PasswordWarning::Permissions(passfile, mode))); + } + } + + let file = match OpenOptions::new().read(true).open(&passfile) { + Err(err) if err.kind() == ErrorKind::PermissionDenied => { + *self = Password::Unspecified; + return Ok(Some(PasswordWarning::NotAccessible(passfile))); + } + res => res?, + }; + if let Some(password) = read_password_file( + hosts, + database, + user, + std::io::read_to_string(file)?.split('\n'), + ) { + *self = 
Password::Specified(password); + } else { + *self = Password::Unspecified; + } + Ok(None) + } +} + +pub trait UserProfile { + fn username(&self) -> Option>; + fn homedir(&self) -> Option>; +} + +pub trait EnvVar { + fn read(&self, name: &'static str) -> Option>; +} + +impl EnvVar for HashMap +where + K: std::hash::Hash + Eq + std::borrow::Borrow, + V: std::borrow::Borrow, +{ + fn read(&self, name: &'static str) -> Option> { + self.get(name).map(|value| value.borrow().into()) + } +} + +impl EnvVar for std::env::Vars { + fn read(&self, name: &'static str) -> Option> { + if let Ok(value) = std::env::var(name) { + Some(value.into()) + } else { + None + } + } +} + +impl EnvVar for &[(&str, &str)] { + fn read(&self, name: &'static str) -> Option> { + for (key, value) in self.iter() { + if *key == name { + return Some((*value).into()); + } + } + None + } +} + +impl EnvVar for () { + fn read(&self, _: &'static str) -> Option> { + None + } +} + +fn maybe_decode(str: Cow) -> Cow { + if str.contains('%') { + if let Ok(str) = percent_decode_str(&str).decode_utf8() { + str.into_owned().into() + } else { + str.into_owned().into() + } + } else { + str + } +} + +fn parse_port(port: &str) -> Result, ParseError> { + if port.is_empty() { + Ok(None) + } else if port.contains('%') { + let decoded = percent_decode_str(port) + .decode_utf8() + .map_err(|_| ParseError::InvalidPort(port.to_string()))?; + Ok(Some( + decoded + .parse::() + .map_err(|_| ParseError::InvalidPort(port.to_string()))?, + )) + } else { + Ok(Some( + port.parse::() + .map_err(|_| ParseError::InvalidPort(port.to_string()))?, + )) + } +} + +fn parse_hostlist( + hostspecs: I, +) -> Result<(Vec>, Vec>), ParseError> +where + I: IntoIterator, + S: AsRef, +{ + let mut hosts = vec![]; + let mut ports = vec![]; + let mut non_empty_host = false; + let mut non_empty_port = false; + + for hostspec in hostspecs { + let hostspec = hostspec.as_ref(); + let (host, port) = if let Some(port) = hostspec.strip_prefix(':') { + (None, 
parse_port(port)?) + } else if hostspec.starts_with('/') { + (Some(HostType::Path(hostspec.to_string())), None) + } else if hostspec.starts_with('[') { + let end_bracket = hostspec + .find(']') + .ok_or_else(|| ParseError::InvalidHostname(hostspec.to_string()))?; + + let (host_part, port_part) = hostspec.split_at(end_bracket + 1); + let host = HostType::try_from_str(&host_part[1..end_bracket])?; + + let port = if let Some(stripped) = port_part.strip_prefix(':') { + parse_port(stripped)? + } else if !port_part.is_empty() { + return Err(ParseError::InvalidHostname(hostspec.to_string())); + } else { + None + }; + (Some(host), port) + } else { + let parts: Vec<&str> = hostspec.split(':').collect(); + let addr = parts[0].to_string(); + let port = if parts.len() > 1 && !parts[1].is_empty() { + parse_port(parts[1])? + } else { + None + }; + + if let Ok(ip) = Ipv4Addr::from_str(&addr) { + (Some(HostType::IP(IpAddr::V4(ip), None)), port) + } else { + (Some(HostType::Hostname(addr)), port) + } + }; + + non_empty_host |= host.is_some(); + hosts.push(host); + non_empty_port |= port.is_some(); + ports.push(port); + } + if !non_empty_host && hosts.len() == 1 { + hosts.clear(); + } + if !non_empty_port && ports.len() == 1 { + ports.clear(); + } + Ok((hosts, ports)) +} + +pub fn parse_postgres_dsn(url_str: &str) -> Result { + let url_str = if let Some(url) = url_str.strip_prefix("postgres://") { + url + } else if let Some(url) = url_str.strip_prefix("postgresql://") { + url + } else { + return Err(ParseError::InvalidScheme( + url_str.split(':').next().unwrap_or_default().to_owned(), + )); + }; + + // Validate percent encoding + let mut chars = url_str.chars().peekable(); + while let Some(c) = chars.next() { + if c == '%' { + let hex1 = chars.next().ok_or(ParseError::InvalidPercentEncoding)?; + let hex2 = chars.next().ok_or(ParseError::InvalidPercentEncoding)?; + + if !hex1.is_ascii_hexdigit() || !hex2.is_ascii_hexdigit() { + return Err(ParseError::InvalidPercentEncoding); + } + + 
// Check for %00 + if hex1 == '0' && hex2 == '0' { + return Err(ParseError::InvalidPercentEncoding); + } + } + } + + // Postgres allows for hostnames surrounded by [] to contain pathnames + let (authority, path_and_query) = { + let mut in_brackets = false; + let mut chars = url_str.char_indices(); + + loop { + if let Some((i, c)) = chars.next() { + match c { + '[' => in_brackets = true, + ']' => in_brackets = false, + '?' | '/' if !in_brackets => { + break url_str.split_at(i); + } + _ => {} + } + } else { + if in_brackets { + return Err(ParseError::InvalidHostname(url_str.to_string())); + } + break (url_str, ""); + } + } + }; + + let (auth, host) = match authority.split_once('@') { + Some((auth, host)) => (auth, host), + None => ("", authority), + }; + + let url = Url::parse(&format!("unused://{auth}@host{path_and_query}"))?; + + let mut raw_params = RawConnectionParameters::<'static>::default(); + + if host.is_empty() { + raw_params.host = None; + raw_params.port = None; + } else { + let (hosts, ports) = parse_hostlist(maybe_decode(host.into()).split(','))?; + if !hosts.is_empty() { + raw_params.host = Some(hosts); + } + if !ports.is_empty() { + raw_params.port = Some(ports); + } + }; + + raw_params.user = match url.username() { + "" => None, + user => { + let decoded = percent_decode_str(user); + if let Ok(user) = decoded.decode_utf8() { + Some(Cow::Owned(user.to_string())) + } else { + Some(Cow::Owned(user.to_string())) + } + } + }; + raw_params.password = url + .password() + .map(|p| p.into()) + .map(maybe_decode) + .map(|s| s.into_owned()) + .map(Cow::Owned); + raw_params.dbname = match url.path() { + "" | "/" => None, + path => Some(Cow::Owned(path.trim_start_matches('/').to_string())), + } + .map(maybe_decode) + .map(|s| s.into_owned()) + .map(Cow::Owned); + + // Validate URL query parameters + let query_str = url.query().unwrap_or(""); + let key_value_pairs = query_str.split('&'); + + for pair in key_value_pairs { + if pair.is_empty() { + continue; + } + + 
if !pair.contains('=') { + return Err(ParseError::InvalidQueryParameter(pair.to_string())); + } + + let parts: Vec<&str> = pair.split('=').collect(); + if parts.len() > 2 { + return Err(ParseError::InvalidQueryParameter(pair.to_string())); + } + + if parts[0].is_empty() { + return Err(ParseError::InvalidQueryParameter(pair.to_string())); + } + } + + for (mut name, value) in url.query_pairs() { + // Intentional difference: database is an alias for dbname + if name == "database" { + name = Cow::Borrowed("dbname"); + } + + raw_params.set_by_name(&name, value.into_owned().into())?; + } + + Ok(raw_params) +} + +pub fn parse_postgres_dsn_env( + url_str: &str, + env: impl EnvVar, +) -> Result { + let mut raw_params = parse_postgres_dsn(url_str)?; + raw_params.apply_env(env)?; + raw_params.try_into() +} + +pub(crate) fn params_to_url(params: &RawConnectionParameters) -> String { + let mut url = String::from("postgresql://"); + let mut params_vec: Vec<(Cow<'_, str>, Cow<'_, str>)> = Vec::new(); + + // Add user and password if present + if let Some(user) = ¶ms.user { + url.extend(utf8_percent_encode(user, ENCODING)); + if let Some(password) = ¶ms.password { + url.push(':'); + url.extend(utf8_percent_encode(password, ENCODING)); + } + url.push('@'); + } else if let Some(password) = ¶ms.password { + url.push(':'); + url.extend(utf8_percent_encode(password, ENCODING)); + url.push('@'); + } + + // Add host and port + let host_count = params.host.as_ref().map_or(0, |h| h.len()); + let port_count = params.port.as_ref().map_or(0, |p| p.len()); + + if host_count <= 1 && port_count <= 1 { + // Add host to authority part + if let Some(hosts) = ¶ms.host { + if let Some(Some(host)) = hosts.first() { + match host { + HostType::Hostname(h) => { + url.push_str(h); + } + HostType::IP(ip, Some(h)) => { + url.push('['); + url.push_str(&ip.to_string()); + url.push_str("%25"); + url.push_str(h); + url.push(']'); + } + HostType::IP(ip, None) => { + url.push('['); + url.push_str(&ip.to_string()); 
+ url.push(']'); + } + _ => { + // Unix socket paths go in the params, not in the authority + params_vec.push((Cow::Borrowed("host"), Cow::Owned(host.to_string()))); + } + } + } else { + params_vec.push((Cow::Borrowed("host"), Cow::Borrowed(""))); + } + } + + // Add port to authority part + if let Some(ports) = ¶ms.port { + match ports.first() { + Some(Some(port)) => { + url.push(':'); + url.push_str(&port.to_string()); + } + Some(None) => { + params_vec.push((Cow::Borrowed("port"), Cow::Borrowed(""))); + } + None => {} + } + } + } else { + // Add hosts to params_vec + if let Some(hosts) = ¶ms.host { + let host_str: String = hosts + .iter() + .map(|h| h.as_ref().map_or("".to_string(), |h| h.to_string())) + .collect::>() + .join(","); + params_vec.push((Cow::Borrowed("host"), Cow::Owned(host_str))); + } + + // Add ports to params_vec + if let Some(ports) = ¶ms.port { + let port_str: String = ports + .iter() + .map(|&p| p.map_or("".to_string(), |p| p.to_string())) + .collect::>() + .join(","); + params_vec.push((Cow::Borrowed("port"), Cow::Owned(port_str))); + } + } + + // Add database if present + if let Some(db) = ¶ms.dbname { + url.push('/'); + url.extend(utf8_percent_encode(db, ENCODING)); + } + + // Add other parameters + let mut has_query = false; + + if !params_vec.is_empty() { + has_query = true; + } + + if !params_vec.is_empty() { + url.push('?'); + url.push_str( + ¶ms_vec + .iter() + .map(|(k, v)| { + format!( + "{}={}", + utf8_percent_encode(k, ENCODING), + utf8_percent_encode(v, ENCODING) + ) + }) + .collect::>() + .join("&"), + ); + } + + params.visit_query_only(|key, value| { + if !has_query { + url.push('?'); + has_query = true; + } else { + url.push('&'); + } + url.extend(utf8_percent_encode(key, ENCODING)); + url.push('='); + url.extend(utf8_percent_encode(value, ENCODING)); + }); + + url +} + +fn read_password_file( + hosts: &[Host], + database: &str, + user: &str, + reader: impl Iterator>, +) -> Option { + for line in reader { + let line = 
line.as_ref().trim(); + + if line.is_empty() || line.starts_with('#') { + continue; + } + + let mut parts = vec![String::new()]; + let mut backslash = false; + for c in line.chars() { + if backslash { + parts.last_mut().unwrap().push(c); + backslash = false; + continue; + } + if c == '\\' { + backslash = true; + continue; + } + if c == ':' && parts.len() <= 4 { + parts.push(String::new()); + continue; + } + parts.last_mut().unwrap().push(c); + } + + if parts.len() == 5 { + for host in hosts { + match &host.0 { + HostType::Hostname(hostname) => { + if parts[0] != "*" && parts[0] != hostname.as_str() { + continue; + } + } + HostType::IP(hostname, _) => { + if parts[0] != "*" && str::parse(&parts[0]) != Ok(*hostname) { + continue; + } + } + HostType::Path(_) | HostType::Abstract(_) => { + if parts[0] != "*" && parts[0] != "localhost" { + continue; + } + } + }; + if parts[1] != "*" && str::parse(&parts[1]) != Ok(host.1) { + continue; + } + if parts[2] != "*" && parts[2] != database { + continue; + } + if parts[3] != "*" && parts[3] != user { + continue; + } + return Some(parts.pop().unwrap()); + } + } + } + + None +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::super::raw_params::SslMode; + use super::*; + use pretty_assertions::assert_eq; + use rstest::rstest; + + #[rstest] + #[case( + &[":1"], + Ok((vec![], vec![Some(1)])) + )] + #[case( + &[":1", ":2"], + Ok((vec!["", ""], vec![Some(1), Some(2)])) + )] + #[case( + &["hostname"], + Ok((vec!["hostname"], vec![])) + )] + #[case( + &["hostname:4321"], + Ok((vec!["hostname"], vec![Some(4321)])) + )] + #[case( + &["/path"], + Ok((vec!["/path"], vec![])) + )] + #[case( + &["[2001:db8::1234]", "[::1]"], + Ok((vec!["2001:db8::1234", "::1"], vec![None, None])) + )] + #[case( + &["[2001:db8::1234%eth0]"], + Ok((vec!["2001:db8::1234%eth0"], vec![])) + )] + #[case( + &["[::1]z"], + Err(ParseError::InvalidHostname("[::1]z".to_owned())) + )] + fn test_parse_hostlist( + #[case] input: &[&str], + #[case] 
expected: Result<(Vec<&'static str>, Vec>), ParseError>, + ) { + let result = parse_hostlist(input); + let expected_host_types = expected.map(|(hosts, ports)| { + ( + hosts + .into_iter() + .map(|h| HostType::try_from_str(h).ok()) + .collect(), + ports, + ) + }); + assert_eq!(expected_host_types, result); + } + + #[test] + fn test_parse_password_file() { + let input = r#" +abc:*:*:user:password from pgpass for user@abc +localhost:*:*:*:password from pgpass for localhost +cde:5433:*:*:password from pgpass for cde:5433 + +*:*:*:testuser:password from pgpass for testuser +*:*:testdb:*:password from pgpass for testdb +# comment +*:*:test\:db:test\\:password from pgpass with escapes + "# + .trim(); + + for (host, database, user, output) in [ + ( + Host(HostType::Hostname("abc".to_owned()), 1234), + "database", + "user", + Some("password from pgpass for user@abc"), + ), + ( + Host(HostType::Hostname("localhost".to_owned()), 1234), + "database", + "user", + Some("password from pgpass for localhost"), + ), + ( + Host(HostType::Path("/tmp".into()), 1234), + "database", + "user", + Some("password from pgpass for localhost"), + ), + ( + Host(HostType::Hostname("hmm".to_owned()), 1234), + "database", + "testuser", + Some("password from pgpass for testuser"), + ), + ( + Host(HostType::Hostname("hostname".to_owned()), 1234), + "test:db", + r#"test\"#, + Some("password from pgpass with escapes"), + ), + ( + Host(HostType::Hostname("doesntexist".to_owned()), 1234), + "db", + "user", + None, + ), + ] { + assert_eq!( + read_password_file(&[host], database, user, input.split('\n')), + output.map(|s| s.to_owned()) + ); + } + } + + #[test] + fn test_parse_dsn() { + assert_eq!( + parse_postgres_dsn_env( + "postgres://", + [ + ("PGUSER", "user"), + ("PGDATABASE", "testdb"), + ("PGPASSWORD", "passw"), + ("PGHOST", "host"), + ("PGPORT", "123"), + ("PGCONNECT_TIMEOUT", "8"), + ] + .as_slice() + ) + .unwrap(), + ConnectionParameters { + hosts: 
vec![Host(HostType::Hostname("host".to_string()), 123)], + database: "testdb".to_string(), + user: "user".to_string(), + password: Password::Specified("passw".to_string()), + connect_timeout: Some(Duration::from_secs(8)), + ssl: Ssl::Enable(SslMode::Prefer, Default::default()), + ..Default::default() + } + ); + + assert_eq!( + parse_postgres_dsn_env("postgres://user:pass@host:1234/database", ()).unwrap(), + ConnectionParameters { + hosts: vec![Host(HostType::Hostname("host".to_string()), 1234)], + database: "database".to_string(), + user: "user".to_string(), + password: Password::Specified("pass".to_string()), + ssl: Ssl::Enable(SslMode::Prefer, Default::default()), + ..Default::default() + } + ); + + assert_eq!( + parse_postgres_dsn_env("postgresql://user@host1:1111,host2:2222/db", ()).unwrap(), + ConnectionParameters { + hosts: vec![ + Host(HostType::Hostname("host1".to_string()), 1111), + Host(HostType::Hostname("host2".to_string()), 2222), + ], + database: "db".to_string(), + user: "user".to_string(), + password: Password::Unspecified, + ssl: Ssl::Enable(SslMode::Prefer, Default::default()), + ..Default::default() + } + ); + } + + #[test] + fn test_dsn_with_slashes() { + assert_eq!( + parse_postgres_dsn_env( + r#"postgres://test\\@fgh/test\:db?passfile=/tmp/tmpkrjuaje4"#, + () + ) + .unwrap(), + ConnectionParameters { + hosts: vec![Host(HostType::Hostname("fgh".to_string()), 5432)], + database: r#"test\:db"#.to_string(), + user: r#"test\\"#.to_string(), + password: Password::Passfile("/tmp/tmpkrjuaje4".to_string().into()), + ssl: Ssl::Enable(SslMode::Prefer, Default::default()), + ..Default::default() + } + ); + } + + #[test] + fn test_dsn_with_escapes() { + assert_eq!( + parse_postgres_dsn_env("postgresql://us%40r:p%40ss@h%40st1,h%40st2:543%33/d%62", ()) + .unwrap(), + ConnectionParameters { + hosts: vec![ + Host(HostType::Hostname("h@st1".to_string()), 5432), + Host(HostType::Hostname("h@st2".to_string()), 5433), + ], + database: "db".to_string(), + user: 
"us@r".to_string(), + password: Password::Specified("p@ss".to_string()), + ssl: Ssl::Enable(SslMode::Prefer, Default::default()), + ..Default::default() + } + ); + } + + #[test] + fn test_dsn_no_slash() { + assert_eq!( + parse_postgres_dsn_env("postgres://user@?port=56226&host=%2Ftmp", ()).unwrap(), + ConnectionParameters { + hosts: vec![Host(HostType::Path("/tmp".to_string()), 56226)], + database: "user".to_string(), + user: "user".to_string(), + password: Password::Unspecified, + ssl: Ssl::Disable, + ..Default::default() + } + ); + } +} diff --git a/edb/server/pgrust/src/connection/mod.rs b/edb/server/pgrust/src/connection/mod.rs new file mode 100644 index 000000000000..8eb19ddb380b --- /dev/null +++ b/edb/server/pgrust/src/connection/mod.rs @@ -0,0 +1,134 @@ +use std::collections::HashMap; + +use crate::auth::{self}; +mod conn; +pub mod dsn; +mod openssl; +pub mod params; +mod raw_conn; +mod raw_params; +pub(crate) mod state_machine; +mod stream; +pub mod tokio; + +pub use conn::Client; +pub use dsn::parse_postgres_dsn_env; +pub use raw_conn::connect_raw_ssl; +pub use raw_params::{Host, HostType, RawConnectionParameters, SslMode, SslVersion}; +pub use state_machine::{Authentication, ConnectionSslRequirement}; + +macro_rules! __invalid_state { + ($error:literal) => {{ + eprintln!( + "Invalid connection state: {}\n{}", + $error, + ::std::backtrace::Backtrace::capture() + ); + #[allow(deprecated)] + $crate::connection::ConnectionError::__InvalidState + }}; +} +pub(crate) use __invalid_state as invalid_state; + +#[derive(Debug, thiserror::Error)] +pub enum ConnectionError { + /// Invalid state error, suggesting a logic error in code rather than a server or client failure. + /// Use the `invalid_state!` macro instead which will print a backtrace. + #[error("Invalid state")] + #[deprecated = "Use invalid_state!"] + __InvalidState, + + /// Error returned by the server. 
+ #[error("Server error: {code}: {message}")] + ServerError { + code: String, + message: String, + extra: HashMap, + }, + + /// The server sent something we didn't expect + #[error("Unexpected server response: {0}")] + UnexpectedServerResponse(String), + + /// Error related to SCRAM authentication. + #[error("SCRAM: {0}")] + Scram(#[from] auth::SCRAMError), + + /// I/O error encountered during connection operations. + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + /// UTF-8 decoding error. + #[error("UTF8 error: {0}")] + Utf8Error(#[from] std::str::Utf8Error), + + /// SSL-related error. + #[error("SSL error: {0}")] + SslError(#[from] SslError), +} + +#[derive(Debug, thiserror::Error)] +pub enum SslError { + #[error("SSL is not supported by this client transport")] + SslUnsupportedByClient, + #[error("SSL was required by the client, but not offered by server (rejected SSL)")] + SslRequiredByClient, + #[error("OpenSSL error: {0}")] + OpenSslError(#[from] ::openssl::ssl::Error), + #[error("OpenSSL error: {0}")] + OpenSslErrorStack(#[from] ::openssl::error::ErrorStack), +} + +#[derive(Clone, Default, derive_more::Debug)] +pub struct Credentials { + pub username: String, + #[debug(skip)] + pub password: String, + pub database: String, + pub server_settings: HashMap, +} + +/// Enum representing the field types in ErrorResponse and NoticeResponse messages. 
+/// +/// See +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, derive_more::TryFrom)] +#[try_from(repr)] +pub enum ServerErrorField { + /// Severity: ERROR, FATAL, PANIC, WARNING, NOTICE, DEBUG, INFO, or LOG + Severity = b'S', + /// Severity (non-localized): ERROR, FATAL, PANIC, WARNING, NOTICE, DEBUG, INFO, or LOG + SeverityNonLocalized = b'V', + /// SQLSTATE code for the error + Code = b'C', + /// Primary human-readable error message + Message = b'M', + /// Optional secondary error message with more detail + Detail = b'D', + /// Optional suggestion on how to resolve the problem + Hint = b'H', + /// Error cursor position as an index into the original query string + Position = b'P', + /// Internal position for internally generated commands + InternalPosition = b'p', + /// Text of a failed internally-generated command + InternalQuery = b'q', + /// Context in which the error occurred (e.g., call stack traceback) + Where = b'W', + /// Schema name associated with the error + SchemaName = b's', + /// Table name associated with the error + TableName = b't', + /// Column name associated with the error + ColumnName = b'c', + /// Data type name associated with the error + DataTypeName = b'd', + /// Constraint name associated with the error + ConstraintName = b'n', + /// Source-code file name where the error was reported + File = b'F', + /// Source-code line number where the error was reported + Line = b'L', + /// Source-code routine name reporting the error + Routine = b'R', +} diff --git a/edb/server/pgrust/src/connection/openssl.rs b/edb/server/pgrust/src/connection/openssl.rs new file mode 100644 index 000000000000..36c69b630297 --- /dev/null +++ b/edb/server/pgrust/src/connection/openssl.rs @@ -0,0 +1,121 @@ +use std::pin::Pin; + +use openssl::{ + ssl::{SslContextBuilder, SslVerifyMode}, + x509::verify::X509VerifyFlags, +}; + +use super::{ + params::SslParameters, + raw_params::SslMode, + stream::{Stream, StreamWithUpgrade}, + SslError, +}; + +impl 
StreamWithUpgrade for (S, openssl::ssl::Ssl) { + type Base = S; + type Config = openssl::ssl::Ssl; + type Upgrade = tokio_openssl::SslStream; + + async fn secure_upgrade(self) -> Result + where + Self: Sized, + { + let mut stream = + tokio_openssl::SslStream::new(self.1, self.0).map_err(SslError::OpenSslErrorStack)?; + Pin::new(&mut stream) + .do_handshake() + .await + .map_err(SslError::OpenSslError)?; + Ok(stream) + } +} + +/// Given a set of [`SslParameters`], configures an OpenSSL context. +pub fn create_ssl_client_context( + mut ssl: SslContextBuilder, + ssl_mode: SslMode, + parameters: SslParameters, +) -> Result> { + let SslParameters { + cert, + key, + password, + rootcert, + crl, + min_protocol_version, + max_protocol_version, + keylog_filename, + verify_crl_check_chain, + } = parameters; + + if ssl_mode >= SslMode::Require { + // Load root cert + if let Some(root) = rootcert { + ssl.set_ca_file(root)?; + ssl.set_verify(SslVerifyMode::PEER); + } else if ssl_mode == SslMode::Require { + ssl.set_verify(SslVerifyMode::NONE); + } + + // Load CRL + if let Some(crl) = &crl { + ssl.set_ca_file(crl)?; + ssl.verify_param_mut() + .set_flags(X509VerifyFlags::CRL_CHECK | X509VerifyFlags::CRL_CHECK_ALL)?; + } + } + + // Load certificate chain and private key + if let (Some(cert), Some(key)) = (cert.as_ref(), key.as_ref()) { + let builder = openssl::x509::X509::from_pem(&std::fs::read(cert)?)?; + ssl.set_certificate(&builder)?; + let key = std::fs::read(key)?; + let key = if let Some(password) = password { + openssl::pkey::PKey::private_key_from_pem_passphrase(&key, password.as_bytes())? + } else { + openssl::pkey::PKey::private_key_from_pem(&key)? 
+ }; + ssl.set_private_key(&key)?; + } + + // Configure hostname verification + if ssl_mode == SslMode::VerifyFull { + ssl.set_verify(SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT); + } + + ssl.set_min_proto_version(min_protocol_version.map(|s| s.into()))?; + ssl.set_max_proto_version(max_protocol_version.map(|s| s.into()))?; + + // // Configure key log filename + // if let Some(keylog_filename) = ¶meters.keylog_filename { + // context_builder.set_keylog_file(keylog_filename)?; + // } + + Ok(ssl) +} + +#[cfg(test)] +mod tests { + use openssl::ssl::SslMethod; + use std::path::Path; + + use super::*; + + #[test] + fn create_ssl() { + let cert_path = Path::new("../../../tests/certs").canonicalize().unwrap(); + + let ssl = SslContextBuilder::new(SslMethod::tls()).unwrap(); + let ssl = create_ssl_client_context( + ssl, + SslMode::VerifyFull, + SslParameters { + cert: Some(cert_path.join("client.cert.pem")), + key: Some(cert_path.join("client.key.pem")), + ..Default::default() + }, + ) + .unwrap(); + } +} diff --git a/edb/server/pgrust/src/connection/params.rs b/edb/server/pgrust/src/connection/params.rs new file mode 100644 index 000000000000..fd7b33ec76ec --- /dev/null +++ b/edb/server/pgrust/src/connection/params.rs @@ -0,0 +1,412 @@ +use super::dsn::{parse_postgres_dsn, EnvVar}; +use super::raw_params::{Host, HostType, RawConnectionParameters, SslMode, SslVersion}; +use serde_derive::Serialize; +use std::collections::HashMap; +use std::io::ErrorKind; +use std::path::{Path, PathBuf}; +use std::time::Duration; +use thiserror::Error; + +impl<'a, I: Into>, E: EnvVar> TryInto for (I, E) { + type Error = ParseError; + fn try_into(self) -> Result { + let mut raw = self.0.into(); + raw.apply_env(self.1)?; + raw.try_into() + } +} + +impl TryInto for String { + type Error = ParseError; + fn try_into(self) -> Result { + let params = parse_postgres_dsn(&self)?; + params.try_into() + } +} + +#[derive(Error, Debug, PartialEq, Eq)] +#[allow(clippy::enum_variant_names)] 
+pub enum ParseError { + #[error( + "Invalid DSN: scheme is expected to be either \"postgresql\" or \"postgres\", got {0}" + )] + InvalidScheme(String), + + #[error("Invalid value for parameter \"{0}\": \"{1}\"")] + InvalidParameter(String, String), + + #[error("Invalid percent encoding")] + InvalidPercentEncoding, + + #[error("Invalid port: \"{0}\"")] + InvalidPort(String), + + #[error("Unexpected number of ports, must be either a single port or the same number as the host count: \"{0}\"")] + InvalidPortCount(String), + + #[error("Invalid hostname: \"{0}\"")] + InvalidHostname(String), + + #[error("Invalid query parameter: \"{0}\"")] + InvalidQueryParameter(String), + + #[error("Invalid TLS version: \"{0}\"")] + InvalidTLSVersion(String), + + #[error("Could not determine the connection {0}")] + MissingRequiredParameter(String), + + #[error("URL parse error: {0}")] + UrlParseError(#[from] url::ParseError), +} + +#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] +pub enum Password { + /// The password is unspecified and should be read from the user's default + /// passfile if it exists. + #[default] + Unspecified, + /// The password was specified. + Specified(String), + /// The passfile is specified. 
+ Passfile(PathBuf), +} + +#[derive(Serialize)] +pub enum PasswordWarning { + NotFile(PathBuf), + NotExists(PathBuf), + NotAccessible(PathBuf), + Permissions(PathBuf, u32), +} + +impl std::fmt::Display for PasswordWarning { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PasswordWarning::NotFile(path) => write!(f, "Password file {path:?} is not a plain file"), + PasswordWarning::NotExists(path) => write!(f, "Password file {path:?} does not exist"), + PasswordWarning::NotAccessible(path) => write!(f, "Password file {path:?} is not accessible"), + PasswordWarning::Permissions(path, mode) => write!(f, "Password file {path:?} has group or world access ({mode:o}); permissions should be u=rw (0600) or less"), + } + } +} + +impl Password { + pub fn password(&self) -> Option<&str> { + match self { + Password::Specified(password) => Some(password), + _ => None, + } + } +} + +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize)] +pub struct ConnectionParameters { + pub hosts: Vec, + pub database: String, + pub user: String, + pub password: Password, + pub connect_timeout: Option, + pub server_settings: HashMap, + pub ssl: Ssl, +} + +impl From for RawConnectionParameters<'static> { + fn from(val: ConnectionParameters) -> Self { + let mut raw_params = RawConnectionParameters::default(); + + if !val.hosts.is_empty() { + let hosts: Vec> = + val.hosts.iter().map(|h| Some(h.0.clone())).collect(); + raw_params.host = Some(hosts); + let ports: Vec> = val.hosts.iter().map(|h| Some(h.1)).collect(); + raw_params.port = Some(ports); + } + + raw_params.dbname = Some(val.database.into()); + raw_params.user = Some(val.user.into()); + + match val.password { + Password::Specified(ref pw) => { + raw_params.password = Some(pw.to_string().into()); + } + Password::Passfile(ref path) => { + raw_params.passfile = Some(path.clone().into()); + } + _ => {} + } + + if let Some(timeout) = val.connect_timeout { + raw_params.connect_timeout = 
Some(timeout.as_secs() as isize); + } + + match val.ssl { + Ssl::Disable => { + raw_params.sslmode = Some(SslMode::Disable); + } + Ssl::Enable(mode, ref params) => { + raw_params.sslmode = Some(mode); + if let Some(ref cert) = params.cert { + raw_params.sslcert = Some(cert.clone().into()); + } + if let Some(ref key) = params.key { + raw_params.sslkey = Some(key.clone().into()); + } + if let Some(ref password) = params.password { + raw_params.sslpassword = Some(password.to_string().into()); + } + if let Some(ref rootcert) = params.rootcert { + raw_params.sslrootcert = Some(rootcert.clone().into()); + } + if let Some(ref crl) = params.crl { + raw_params.sslcrl = Some(crl.clone().into()); + } + raw_params.ssl_min_protocol_version = params.min_protocol_version; + raw_params.ssl_max_protocol_version = params.max_protocol_version; + } + } + + raw_params.server_settings = Some( + val.server_settings + .into_iter() + .map(|(k, v)| (k.into(), v.into())) + .collect(), + ); + + raw_params + } +} + +impl TryFrom> for ConnectionParameters { + type Error = ParseError; + + fn try_from(raw_params: RawConnectionParameters<'_>) -> Result { + fn merge_hosts_and_ports( + host_types: &[Option], + mut specified_ports: &[Option], + ) -> Result, ParseError> { + let mut hosts = vec![]; + + if host_types.is_empty() { + return merge_hosts_and_ports( + &[ + Some(HostType::Path("/var/run/postgresql".to_string())), + Some(HostType::Path("/run/postgresql".to_string())), + Some(HostType::Path("/tmp".to_string())), + Some(HostType::Path("/private/tmp".to_string())), + Some(HostType::Hostname("localhost".to_string())), + ], + specified_ports, + ); + } + + if specified_ports.is_empty() { + specified_ports = &[Some(5432)]; + } else if specified_ports.len() != host_types.len() && specified_ports.len() > 1 { + return Err(ParseError::InvalidPortCount(format!("{specified_ports:?}"))); + } + + for (i, host_type) in host_types.iter().enumerate() { + let host_type = host_type + .clone() + .unwrap_or_else(|| 
HostType::Path("/var/run/postgresql".to_string())); + let port = specified_ports[i % specified_ports.len()].unwrap_or(5432); + + hosts.push(Host(host_type, port)); + } + Ok(hosts) + } + + let hosts = merge_hosts_and_ports( + &raw_params.host.unwrap_or_default(), + &raw_params.port.unwrap_or_default(), + )?; + + if hosts.is_empty() { + return Err(ParseError::MissingRequiredParameter("host".to_string())); + } + + let user = raw_params + .user + .ok_or_else(|| ParseError::MissingRequiredParameter("user".to_string()))?; + let database = raw_params.dbname.unwrap_or_else(|| user.clone()); + + let password = match raw_params.password { + Some(p) => Password::Specified(p.into_owned()), + None => match raw_params.passfile { + Some(passfile) => Password::Passfile(passfile.into_owned()), + None => Password::Unspecified, + }, + }; + + let connect_timeout = raw_params.connect_timeout.and_then(|seconds| { + if seconds <= 0 { + None + } else { + Some(Duration::from_secs(seconds.max(2) as u64)) + } + }); + + let any_tcp = hosts + .iter() + .any(|host| matches!(host.0, HostType::Hostname(..) 
| HostType::IP(..))); + + let ssl_mode = raw_params.sslmode.unwrap_or({ + if any_tcp { + SslMode::Prefer + } else { + SslMode::Disable + } + }); + + let ssl = if ssl_mode == SslMode::Disable { + Ssl::Disable + } else { + let mut ssl = SslParameters::default(); + if ssl_mode >= SslMode::Require { + ssl.rootcert = raw_params.sslrootcert.map(|s| s.into_owned()); + ssl.crl = raw_params.sslcrl.map(|s| s.into_owned()); + } + ssl.key = raw_params.sslkey.map(|s| s.into_owned()); + ssl.cert = raw_params.sslcert.map(|s| s.into_owned()); + ssl.min_protocol_version = raw_params.ssl_min_protocol_version; + ssl.max_protocol_version = raw_params.ssl_max_protocol_version; + ssl.password = raw_params.sslpassword.map(|s| s.into_owned()); + + Ssl::Enable(ssl_mode, ssl) + }; + + Ok(ConnectionParameters { + hosts, + database: database.into_owned(), + user: user.into_owned(), + password, + connect_timeout, + server_settings: raw_params + .server_settings + .unwrap_or_default() + .into_iter() + .map(|(k, v)| (k.into_owned(), v.into_owned())) + .collect(), + ssl, + }) + } +} + +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize)] +#[allow(clippy::large_enum_variant)] +pub enum Ssl { + #[default] + Disable, + Enable(SslMode, SslParameters), +} + +#[derive(Default, Clone, Debug, PartialEq, Eq, Serialize)] +pub struct SslParameters { + pub cert: Option, + pub key: Option, + pub password: Option, + pub rootcert: Option, + pub crl: Option, + pub min_protocol_version: Option, + pub max_protocol_version: Option, + pub keylog_filename: Option, + pub verify_crl_check_chain: Option, +} + +impl Ssl { + /// Resolve the SSL paths relative to the home directory. 
+ pub fn resolve(&mut self, home_dir: &Path) -> Result<(), std::io::Error> { + let postgres_dir = home_dir; + let Ssl::Enable(mode, params) = self else { + return Ok(()); + }; + if *mode >= SslMode::Require { + let root_cert = params + .rootcert + .clone() + .unwrap_or_else(|| postgres_dir.join("root.crt")); + if root_cert.exists() { + params.rootcert = Some(root_cert); + } else if *mode > SslMode::Require { + return Err(std::io::Error::new(ErrorKind::NotFound, + format!("Root certificate not found: {root_cert:?}. Either provide the file or change sslmode to disable SSL certificate verification."))); + } + + let crl = params + .crl + .clone() + .unwrap_or_else(|| postgres_dir.join("root.crl")); + if crl.exists() { + params.crl = Some(crl); + } + } + let key = params + .key + .clone() + .unwrap_or_else(|| postgres_dir.join("postgresql.key")); + if key.exists() { + params.key = Some(key); + } + let cert = params + .cert + .clone() + .unwrap_or_else(|| postgres_dir.join("postgresql.crt")); + if cert.exists() { + params.cert = Some(cert); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::{HostType, ParseError}; + use rstest::rstest; + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + + #[rstest] + #[case("example.com", HostType::Hostname("example.com".to_string()))] + // This should probably parse as IPv4 + #[case("0", HostType::Hostname("0".to_string()))] + #[case( + "192.168.1.1", + HostType::IP(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), None) + )] + #[case( + "2001:db8::1", + HostType::IP(IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), None) + )] + #[case("2001:db8::1%eth0", HostType::IP(IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), Some("eth0".to_string())))] + #[case("/var/run/postgresql", HostType::Path("/var/run/postgresql".to_string()))] + #[case("@abstract", HostType::Abstract("abstract".to_string()))] + fn test_host_type_roundtrip(#[case] input: &str, #[case] expected: HostType) { + let parsed = 
HostType::try_from_str(input).unwrap(); + assert_eq!(parsed, expected, "{input} should have succeeded"); + assert_eq!(parsed.to_string(), input, "{input} should have succeeded"); + } + + #[rstest] + #[case("", ParseError::InvalidHostname("".to_string()))] + #[case("example.com:80", ParseError::InvalidHostname("example.com:80".to_string()))] + #[case("[::1]", ParseError::InvalidHostname("[::1]".to_string()))] + #[case("2001:db8::1%", ParseError::InvalidHostname("2001:db8::1%".to_string()))] + #[case("not:valid:ipv6", ParseError::InvalidHostname("not:valid:ipv6".to_string()))] + fn test_host_type_failures(#[case] input: &str, #[case] expected_error: ParseError) { + let result = HostType::try_from_str(input); + assert!(result.is_err(), "{input} should have failed"); + assert_eq!( + result.unwrap_err(), + expected_error, + "{input} should have failed" + ); + } + + #[test] + fn resolve() { + eprintln!( + "{:?}", + HostType::Hostname("fe80::127c:61ff:fe3d:16d5%lo".to_owned()).resolve() + ); + } +} diff --git a/edb/server/pgrust/src/connection/raw_conn.rs b/edb/server/pgrust/src/connection/raw_conn.rs new file mode 100644 index 000000000000..abeda2310330 --- /dev/null +++ b/edb/server/pgrust/src/connection/raw_conn.rs @@ -0,0 +1,240 @@ +use super::state_machine::{ + Authentication, ConnectionDrive, ConnectionSslRequirement, ConnectionState, + ConnectionStateSend, ConnectionStateType, ConnectionStateUpdate, +}; +use super::{ + stream::{Stream, StreamWithUpgrade, UpgradableStream}, + ConnectionError, Credentials, +}; +use crate::protocol::{meta, SSLResponse, StructBuffer}; +use std::collections::HashMap; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::AsyncWriteExt; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tracing::trace; + +#[derive(Clone, Default, Debug)] +pub struct ConnectionParams { + pub ssl: bool, + pub params: HashMap, + pub cancellation_key: (i32, i32), + pub auth: Authentication, +} + +pub struct ConnectionDriver { + send_buffer: 
Vec<u8>,
    upgrade: bool,
    params: ConnectionParams,
}

// NOTE(review): generic parameters below (`<B, C>`, `Vec<u8>`, etc.) were lost
// in extraction and have been reconstructed from the surviving
// `where (B, C): StreamWithUpgrade` bounds — confirm against the original.

impl ConnectionStateSend for ConnectionDriver {
    /// Queue the initial (startup or SSLRequest) message; bytes are written on
    /// the next flush.
    fn send_initial(
        &mut self,
        message: crate::protocol::definition::InitialBuilder,
    ) -> Result<(), std::io::Error> {
        self.send_buffer.extend(message.to_vec());
        Ok(())
    }
    /// Queue a regular frontend message; bytes are written on the next flush.
    fn send(
        &mut self,
        message: crate::protocol::definition::FrontendBuilder,
    ) -> Result<(), std::io::Error> {
        self.send_buffer.extend(message.to_vec());
        Ok(())
    }
    /// Record that the server accepted SSL. The TLS handshake itself is
    /// performed later by the driver loop, not here.
    fn upgrade(&mut self) -> Result<(), std::io::Error> {
        self.upgrade = true;
        self.params.ssl = true;
        Ok(())
    }
}

impl ConnectionStateUpdate for ConnectionDriver {
    fn state_changed(&mut self, state: ConnectionStateType) {
        trace!("State: {state:?}");
    }
    fn cancellation_key(&mut self, pid: i32, key: i32) {
        self.params.cancellation_key = (pid, key);
    }
    fn parameter(&mut self, name: &str, value: &str) {
        self.params.params.insert(name.to_owned(), value.to_owned());
    }
    fn auth(&mut self, auth: Authentication) {
        trace!("Auth: {auth:?}");
        self.params.auth = auth;
    }
}

impl ConnectionDriver {
    pub fn new() -> Self {
        Self {
            send_buffer: Vec::new(),
            upgrade: false,
            params: ConnectionParams::default(),
        }
    }

    /// Flush any queued frontend bytes and, if the state machine requested it,
    /// perform the TLS upgrade and notify the state machine.
    ///
    /// This loop was previously duplicated verbatim in [`Self::drive_bytes`]
    /// and [`Self::drive`].
    async fn flush_and_upgrade<B: Stream, C: Unpin>(
        &mut self,
        state: &mut ConnectionState,
        stream: &mut UpgradableStream<B, C>,
    ) -> Result<(), ConnectionError>
    where
        (B, C): StreamWithUpgrade,
    {
        loop {
            if !self.send_buffer.is_empty() {
                // Was println! + hexdump::hexdump to stdout: library code must
                // not write to stdout; route diagnostics through tracing.
                trace!("Write: {} bytes", self.send_buffer.len());
                stream.write_all(&self.send_buffer).await?;
                self.send_buffer.clear();
            }
            if self.upgrade {
                self.upgrade = false;
                stream.secure_upgrade().await?;
                state.drive(ConnectionDrive::SslReady, self)?;
            } else {
                break;
            }
        }
        Ok(())
    }

    /// Feed raw bytes from the server into the state machine, then flush any
    /// responses it queued.
    async fn drive_bytes<B: Stream, C: Unpin>(
        &mut self,
        state: &mut ConnectionState,
        drive: &[u8],
        message_buffer: &mut StructBuffer<meta::Message>,
        stream: &mut UpgradableStream<B, C>,
    ) -> Result<(), ConnectionError>
    where
        (B, C): StreamWithUpgrade,
    {
        message_buffer.push_fallible(drive, |msg| {
            state.drive(ConnectionDrive::Message(msg), self)
        })?;
        self.flush_and_upgrade(state, stream).await
    }

    /// Feed a single pre-parsed drive event into the state machine, then flush
    /// any responses it queued.
    async fn drive<B: Stream, C: Unpin>(
        &mut self,
        state: &mut ConnectionState,
        drive: ConnectionDrive<'_>,
        stream: &mut UpgradableStream<B, C>,
    ) -> Result<(), ConnectionError>
    where
        (B, C): StreamWithUpgrade,
    {
        state.drive(drive, self)?;
        self.flush_and_upgrade(state, stream).await
    }
}

/// A raw, fully-authenticated stream connection to a backend server.
pub struct RawClient<B: Stream, C: Unpin>
where
    (B, C): StreamWithUpgrade,
{
    stream: UpgradableStream<B, C>,
    params: ConnectionParams,
}

impl<B: Stream, C: Unpin> RawClient<B, C>
where
    (B, C): StreamWithUpgrade,
{
    /// The parameters negotiated during the connection handshake.
    pub fn params(&self) -> &ConnectionParams {
        &self.params
    }
}

impl<B: Stream, C: Unpin> AsyncRead for RawClient<B, C>
where
    (B, C): StreamWithUpgrade,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.get_mut().stream).poll_read(cx, buf)
    }
}

impl<B: Stream, C: Unpin> AsyncWrite for RawClient<B, C>
where
    (B, C): StreamWithUpgrade,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        Pin::new(&mut self.get_mut().stream).poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        Pin::new(&mut self.get_mut().stream).poll_flush(cx)
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        Pin::new(&mut self.get_mut().stream).poll_shutdown(cx)
    }
}

/// Connect to a backend, performing the optional SSL upgrade and the
/// authentication handshake, and return the ready-to-use client.
pub async fn connect_raw_ssl<B: Stream, C: Unpin>(
    credentials: Credentials,
    ssl_mode: ConnectionSslRequirement,
    config: C,
    socket: B,
) -> Result<RawClient<B, C>, ConnectionError>
where
    (B, C): StreamWithUpgrade,
{
    let mut state = ConnectionState::new(credentials, ssl_mode);
    let mut stream = UpgradableStream::from((socket, config));

    let mut update = ConnectionDriver::new();
    update
        .drive(&mut state, ConnectionDrive::Initial, &mut stream)
.await?;

    let mut struct_buffer: StructBuffer<meta::Message> = StructBuffer::default();

    // Pump the handshake until the state machine reports ReadyForQuery.
    while !state.is_ready() {
        let mut buffer = [0; 1024];
        let n = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await?;
        if n == 0 {
            // A clean EOF before the handshake completes is a protocol error.
            Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?;
        }
        // Was println!("Read:") + hexdump::hexdump to stdout: library code must
        // not write to stdout; route diagnostics through tracing.
        trace!("Read: {n} bytes");
        if state.read_ssl_response() {
            // The server answers an SSLRequest with a single status byte
            // ('S'/'N'); SSLResponse inspects only the head of the buffer.
            let ssl_response = SSLResponse::new(&buffer);
            update
                .drive(
                    &mut state,
                    ConnectionDrive::SslResponse(ssl_response),
                    &mut stream,
                )
                .await?;
            continue;
        }

        update
            .drive_bytes(&mut state, &buffer[..n], &mut struct_buffer, &mut stream)
            .await?;
    }
    Ok(RawClient {
        stream,
        params: update.params,
    })
}
diff --git a/edb/server/pgrust/src/connection/raw_params.rs b/edb/server/pgrust/src/connection/raw_params.rs
new file mode 100644
index 000000000000..1c0ed77e2d56
--- /dev/null
+++ b/edb/server/pgrust/src/connection/raw_params.rs
@@ -0,0 +1,553 @@
use super::dsn::params_to_url;
use super::params::ParseError;
use serde_derive::Serialize;
use std::borrow::Cow;
use std::collections::HashMap;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::path::{Path, PathBuf};

/// Convert from an environment variable or query string to the parameter type.
// NOTE(review): the lifetime/generic parameters here were lost in extraction
// and reconstructed from the `Cow` usages below — confirm against the original.
trait FromEnv<'a>
where
    Self: Sized,
{
    fn from(env: Cow<'a, str>) -> Result<Self, ParseError>;
}

macro_rules!
from_env_impl { + ($ty:ty: $expr:expr) => { + impl FromEnv for $ty { + fn from(env: Cow) -> Result { + ($expr(env)) + } + } + }; +} + +// Define one of these per param type +from_env_impl!(Vec>: |e: Cow| parse_host_param(&e)); +from_env_impl!(Vec>: |e: Cow| parse_port_param(&e)); +from_env_impl!(Cow<'_, str>: |e: Cow| Ok(e.into_owned().into())); +from_env_impl!(Cow<'_, Path>: |e: Cow| Ok(PathBuf::from(e.into_owned()).into())); +from_env_impl!(isize: |e: Cow| parse_connect_timeout(e)); +from_env_impl!(bool: |e: Cow| Ok(e == "1" || e == "true" || e == "on" || e == "yes")); +from_env_impl!(SslMode: |e: Cow| SslMode::try_from(e.as_ref())); +from_env_impl!(SslVersion: |e: Cow| e.try_into()); + +trait ToEnv { + fn to(&self) -> Cow; +} + +macro_rules! to_env_impl { + ($ty:ty: |$self:ident| $expr:expr) => { + impl ToEnv for $ty { + fn to(&self) -> Cow { + let $self = self; + $expr + } + } + }; +} + +to_env_impl!(Cow<'_, str>: |e| Cow::Borrowed(e)); +to_env_impl!(Cow<'_, Path>: |e| e.to_string_lossy()); +to_env_impl!(Vec>: |e| { + Cow::Owned(e.iter().map(|h| match h { + Some(ht) => ht.to_string(), + None => String::new(), + }).collect::>().join(",")) +}); +to_env_impl!(Vec>: |e| { + Cow::Owned(e.iter().map(|p| p.map_or(String::new(), |v| v.to_string())) + .collect::>().join(",")) +}); +to_env_impl!(isize: |e| Cow::Owned(e.to_string())); +to_env_impl!(bool: |e| Cow::Owned(if *e { "1" } else { "0" }.to_string())); +to_env_impl!(SslMode: |e| Cow::Owned(e.to_string())); +to_env_impl!(SslVersion: |e| Cow::Owned(e.to_string())); + +trait RawToOwned { + type Owned; + fn raw_to_owned(&self) -> Self::Owned; +} + +impl<'a, T: ?Sized> RawToOwned for Cow<'a, T> +where + T: ToOwned + 'static, + Cow<'static, T>: From<::Owned>, +{ + type Owned = Cow<'static, T>; + fn raw_to_owned(&self) -> ::Owned { + ToOwned::to_owned(self.as_ref()).into() + } +} + +macro_rules! trivial_raw_to_owned { + ($ty:ident $(< $($generic:ident),* >)?) => { + impl $(<$($generic),*>)? 
RawToOwned for $ty $(<$($generic),*>)? where Self: Clone { + type Owned = Self; + fn raw_to_owned(&self) -> Self::Owned { + self.clone() + } + } + }; +} + +trivial_raw_to_owned!(Vec); +trivial_raw_to_owned!(isize); +trivial_raw_to_owned!(bool); +trivial_raw_to_owned!(SslMode); +trivial_raw_to_owned!(SslVersion); + +macro_rules! define_params { + ($lifetime:lifetime, $( #[doc = $doc:literal] $name:ident: $ty:ty $(, env = $env:literal)? $(, query_only = $query_only:ident)?; )* ) => { + /// [`RawConnectionParameters`] represents the raw, parsed connection parameters. + /// + /// These parameters map directly to the parameters in the DSN and perform only + /// basic validation. + #[derive(Clone, Debug, Default, PartialEq, Eq, Serialize)] + pub struct RawConnectionParameters<$lifetime> { + $( + #[doc = $doc] + pub $name: Option<$ty>, + )* + /// Any additional settings we don't recognize + pub server_settings: Option, Cow<'a, str>>>, + } + + impl<'a> From> for HashMap { + fn from(params: RawConnectionParameters<$lifetime>) -> HashMap { + let mut map = HashMap::new(); + + $( + if let Some(value) = params.$name { + map.insert(stringify!($name).to_string(), <$ty as ToEnv>::to(&value).into_owned()); + } + )* + + if let Some(server_settings) = params.server_settings { + map.extend(server_settings.into_iter().map(|(k, v)| (k.into_owned(), v.into_owned()))); + } + + map + } + } + + impl <$lifetime> RawConnectionParameters<$lifetime> { + pub fn to_static(&self) -> RawConnectionParameters<'static> { + $( + let $name = self.$name.as_ref().map(|v| v.raw_to_owned()); + )* + + let server_settings = self.server_settings.as_ref().map(|m| { + m.iter().map(|(k, v)| (k.raw_to_owned(), v.raw_to_owned())).collect() + }); + + RawConnectionParameters::<'static> { + $( + $name, + )* + server_settings, + } + } + + /// Apply environment variables to the parameters. 
+ pub fn apply_env(&mut self, env: impl crate::connection::dsn::EnvVar) -> Result<(), ParseError> { + $( + $( + if self.$name.is_none() { + if let Some(env_value) = env.read($env) { + self.$name = Some(FromEnv::from(env_value)?); + } + } + )? + )* + Ok(()) + } + + /// Set a parameter by query string name. + pub fn set_by_name(&mut self, name: &str, value: Cow<'a, str>) -> Result<(), ParseError> { + match name { + $( + stringify!($name) => { + self.$name = Some(FromEnv::from(value)?); + }, + )* + _ => { + self.server_settings + .get_or_insert_with(HashMap::new) + .insert(Cow::Owned(name.to_string()), value); + } + } + Ok(()) + } + + /// Get a parameter by query string name. + pub fn get_by_name(&self, name: &str) -> Option> { + match name { + $( + stringify!($name) => { + self.$name.as_ref().map(|value| <$ty as ToEnv>::to(&value)) + }, + )* + _ => { + self.server_settings + .as_ref() + .and_then(|settings| settings.get(name)) + .map(|value| as ToEnv>::to(&value)) + } + } + } + + /// Visit the query-only parameters. These are the parameters that never appears anywhere other than in the query string. + pub(crate) fn visit_query_only(&self, mut f: impl for<'b> FnMut(&'b str, &'b str)) { + $( + $( + stringify!($query_only); + if let Some(value) = &self.$name { + f(stringify!($name), &value.to()); + } + )? + )* + + if let Some(settings) = &self.server_settings { + for (key, value) in settings { + f(key, value); + } + } + } + + /// Returns all field names as a vector of static string slices. 
+ pub fn field_names() -> Vec<&'static str> { + vec![ + $( + stringify!($name), + )* + ] + } + } + }; +} + +impl<'a> RawConnectionParameters<'a> { + pub fn hosts(&self) -> Result, ParseError> { + Self::merge_hosts_and_ports( + self.host.as_deref().unwrap_or_default(), + self.port.as_deref().unwrap_or_default(), + ) + } + + fn merge_hosts_and_ports( + host_types: &[Option], + mut specified_ports: &[Option], + ) -> Result, ParseError> { + let mut hosts = vec![]; + + if host_types.is_empty() { + return Self::merge_hosts_and_ports( + &[ + Some(HostType::Path("/var/run/postgresql".to_string())), + Some(HostType::Path("/run/postgresql".to_string())), + Some(HostType::Path("/tmp".to_string())), + Some(HostType::Path("/private/tmp".to_string())), + Some(HostType::Hostname("localhost".to_string())), + ], + specified_ports, + ); + } + + if specified_ports.is_empty() { + specified_ports = &[Some(5432)]; + } else if specified_ports.len() != host_types.len() && specified_ports.len() > 1 { + return Err(ParseError::InvalidPortCount(format!("{specified_ports:?}"))); + } + + for (i, host_type) in host_types.iter().enumerate() { + let host_type = host_type + .clone() + .unwrap_or_else(|| HostType::Path("/var/run/postgresql".to_string())); + let port = specified_ports[i % specified_ports.len()].unwrap_or(5432); + + hosts.push(Host(host_type, port)); + } + Ok(hosts) + } +} + +define_params!('a, + /// The host to connect to. + host: Vec>, env = "PGHOST"; + /// The port to connect to. + port: Vec>, env = "PGPORT"; + /// The database to connect to. + dbname: Cow<'a, str>, env = "PGDATABASE"; + /// The user to connect as. + user: Cow<'a, str>, env = "PGUSER"; + /// The password to use when connecting. + password: Cow<'a, str>, env = "PGPASSWORD"; + + /// The path to the passfile. + passfile: Cow<'a, Path>, env = "PGPASSFILE", query_only = query_only; + /// The timeout for the connection to be established. 
+ connect_timeout: isize, env = "PGCONNECT_TIMEOUT", query_only = query_only; + /// The SSL mode to use. + sslmode: SslMode, env = "PGSSLMODE", query_only = query_only; + /// The SSL certificate to use. + sslcert: Cow<'a, Path>, env = "PGSSLCERT", query_only = query_only; + /// The SSL key to use. + sslkey: Cow<'a, Path>, env = "PGSSLKEY", query_only = query_only; + /// The SSL password to use. + sslpassword: Cow<'a, str>, query_only = query_only; + /// The SSL root certificate to use. + sslrootcert: Cow<'a, Path>, env = "PGSSLROOTCERT", query_only = query_only; + /// The path to the CRL file. + sslcrl: Cow<'a, Path>, env = "PGSSLCRL", query_only = query_only; + /// The minimum SSL protocol version to use. + ssl_min_protocol_version: SslVersion, env = "PGSSLMINPROTOCOLVERSION", query_only = query_only; + /// The maximum SSL protocol version to use. + ssl_max_protocol_version: SslVersion, env = "PGSSLMAXPROTOCOLVERSION", query_only = query_only; + + /// The path to the file for TLS key log. + keylog_filename: Cow<'a, Path>; + /// Whether to verify the CRL chain. 
+ verify_crl_check_chain: bool; +); + +impl RawConnectionParameters<'_> { + pub fn to_url(&self) -> String { + params_to_url(self) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] +pub struct Host(pub HostType, pub u16); + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] +pub enum HostType { + Hostname(String), + IP(IpAddr, Option), + Path(String), + Abstract(String), +} + +impl std::fmt::Display for HostType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + HostType::Hostname(hostname) => write!(f, "{}", hostname), + HostType::IP(ip, Some(interface)) => write!(f, "{}%{}", ip, interface), + HostType::IP(ip, None) => { + write!(f, "{}", ip) + } + HostType::Path(path) => write!(f, "{}", path), + HostType::Abstract(name) => write!(f, "@{}", name), + } + } +} + +impl HostType { + pub fn try_from_str(s: &str) -> Result { + if s.is_empty() { + return Err(ParseError::InvalidHostname("".to_string())); + } + if s.contains('[') || s.contains(']') { + return Err(ParseError::InvalidHostname(s.to_string())); + } + if s.starts_with('/') { + return Ok(HostType::Path(s.to_string())); + } + if let Some(s) = s.strip_prefix('@') { + return Ok(HostType::Abstract(s.to_string())); + } + if s.contains('%') { + let (ip_str, interface) = s.split_once('%').unwrap(); + if interface.is_empty() { + return Err(ParseError::InvalidHostname(s.to_string())); + } + let ip = ip_str + .parse::() + .map_err(|_| ParseError::InvalidHostname(s.to_string()))?; + return Ok(HostType::IP(IpAddr::V6(ip), Some(interface.to_string()))); + } + if let Ok(ip) = s.parse::() { + Ok(HostType::IP(ip, None)) + } else { + if s.contains(':') { + return Err(ParseError::InvalidHostname(s.to_string())); + } + Ok(HostType::Hostname(s.to_string())) + } + } + + pub(crate) fn resolve(&self) -> std::io::Result> { + match self { + Self::Hostname(host) => { + use std::net::ToSocketAddrs; + Ok((host.as_str(), 1) + .to_socket_addrs()? 
+ .map(|addr| { + eprintln!("{addr:?}"); + match addr { + SocketAddr::V4(addr) => HostType::IP(IpAddr::V4(*addr.ip()), None), + SocketAddr::V6(addr) => HostType::IP(IpAddr::V6(*addr.ip()), None), + } + }) + .collect()) + } + x => Ok(vec![x.clone()]), + } + } +} + +impl std::str::FromStr for HostType { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + HostType::try_from_str(s) + } +} + +/// SSL mode for PostgreSQL connections. +/// +/// For more information, see the [PostgreSQL documentation](https://www.postgresql.org/docs/current/libpq-ssl.html). +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize)] +pub enum SslMode { + /// "I don't care about security, and I don't want to pay the overhead of encryption." + #[serde(rename = "disable")] + Disable, + /// "I don't care about security, but I will pay the overhead of encryption if the server insists on it." + #[serde(rename = "allow")] + Allow, + /// "I don't care about encryption, but I wish to pay the overhead of encryption if the server supports it." + #[serde(rename = "prefer")] + Prefer, + /// "I want my data to be encrypted, and I accept the overhead. I trust that the network will make sure I always connect to the server I want." + #[serde(rename = "require")] + Require, + /// "I want my data encrypted, and I accept the overhead. I want to be sure that I connect to a server that I trust." + #[serde(rename = "verify_ca")] + VerifyCA, + /// "I want my data encrypted, and I accept the overhead. I want to be sure that I connect to a server I trust, and that it's the one I specify." 
+ #[serde(rename = "verify_full")] + VerifyFull, +} + +impl TryFrom<&str> for SslMode { + type Error = ParseError; + + fn try_from(s: &str) -> Result { + match s { + "allow" => Ok(SslMode::Allow), + "prefer" => Ok(SslMode::Prefer), + "require" => Ok(SslMode::Require), + "verify_ca" | "verify-ca" => Ok(SslMode::VerifyCA), + "verify_full" | "verify-full" => Ok(SslMode::VerifyFull), + "disable" => Ok(SslMode::Disable), + _ => Err(ParseError::InvalidParameter( + "sslmode".to_string(), + s.to_string(), + )), + } + } +} +impl std::fmt::Display for SslMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + SslMode::Disable => "disable", + SslMode::Allow => "allow", + SslMode::Prefer => "prefer", + SslMode::Require => "require", + SslMode::VerifyCA => "verify-ca", + SslMode::VerifyFull => "verify-full", + }; + f.write_str(s) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum SslVersion { + Tls1, + Tls1_1, + Tls1_2, + Tls1_3, +} + +impl std::fmt::Display for SslVersion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + SslVersion::Tls1 => "TLSv1", + SslVersion::Tls1_1 => "TLSv1.1", + SslVersion::Tls1_2 => "TLSv1.2", + SslVersion::Tls1_3 => "TLSv1.3", + }; + f.write_str(s) + } +} + +impl serde::Serialize for SslVersion { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(match self { + SslVersion::Tls1 => "TLSv1", + SslVersion::Tls1_1 => "TLSv1.1", + SslVersion::Tls1_2 => "TLSv1.2", + SslVersion::Tls1_3 => "TLSv1.3", + }) + } +} + +impl<'a> TryFrom> for SslVersion { + type Error = ParseError; + fn try_from(value: Cow) -> Result { + Ok(match value.to_lowercase().as_ref() { + "tls_1" | "tlsv1" => SslVersion::Tls1, + "tls_1.1" | "tlsv1.1" => SslVersion::Tls1_1, + "tls_1.2" | "tlsv1.2" => SslVersion::Tls1_2, + "tls_1.3" | "tlsv1.3" => SslVersion::Tls1_3, + _ => return 
Err(ParseError::InvalidTLSVersion(value.to_string())), + }) + } +} + +impl From for openssl::ssl::SslVersion { + fn from(val: SslVersion) -> Self { + match val { + SslVersion::Tls1 => openssl::ssl::SslVersion::TLS1, + SslVersion::Tls1_1 => openssl::ssl::SslVersion::TLS1_1, + SslVersion::Tls1_2 => openssl::ssl::SslVersion::TLS1_2, + SslVersion::Tls1_3 => openssl::ssl::SslVersion::TLS1_3, + } + } +} + +fn parse_host_param(value: &str) -> Result>, ParseError> { + value + .split(',') + .map(|host| { + if host.is_empty() { + Ok(None) + } else { + HostType::try_from_str(host).map(Some) + } + }) + .collect() +} + +fn parse_port_param(port: &str) -> Result>, ParseError> { + port.split(',') + .map(|port| { + (!port.is_empty()) + .then(|| str::parse::(port)) + .transpose() + }) + .collect::>, _>>() + .map_err(|_| ParseError::InvalidPort(port.to_string())) +} + +fn parse_connect_timeout(timeout: Cow) -> Result { + let seconds = timeout.parse::().map_err(|_| { + ParseError::InvalidParameter("connect_timeout".to_string(), timeout.to_string()) + })?; + Ok(seconds) +} diff --git a/edb/server/pgrust/src/connection/state_machine.rs b/edb/server/pgrust/src/connection/state_machine.rs new file mode 100644 index 000000000000..b9fe9f187568 --- /dev/null +++ b/edb/server/pgrust/src/connection/state_machine.rs @@ -0,0 +1,417 @@ +use std::collections::HashMap; + +use super::{invalid_state, ConnectionError, Credentials, ServerErrorField}; +use crate::{ + auth::{self, generate_salted_password, ClientEnvironment, ClientTransaction, Sha256Out}, + connection::SslError, + protocol::{ + builder, + definition::{FrontendBuilder, InitialBuilder}, + match_message, AuthenticationCleartextPassword, AuthenticationMD5Password, + AuthenticationMessage, AuthenticationOk, AuthenticationSASL, AuthenticationSASLContinue, + AuthenticationSASLFinal, BackendKeyData, ErrorResponse, Message, ParameterStatus, + ReadyForQuery, SSLResponse, + }, +}; +use base64::Engine; +use rand::Rng; +use tracing::{trace, warn}; + 
+#[derive(Debug)] +struct ClientEnvironmentImpl { + credentials: Credentials, +} + +impl ClientEnvironment for ClientEnvironmentImpl { + fn generate_nonce(&self) -> String { + let nonce: [u8; 32] = rand::thread_rng().r#gen(); + base64::engine::general_purpose::STANDARD.encode(nonce) + } + fn get_salted_password(&self, salt: &[u8], iterations: usize) -> Sha256Out { + generate_salted_password(self.credentials.password.as_bytes(), salt, iterations) + } +} + +#[derive(Debug)] +enum ConnectionStateImpl { + /// Uninitialized connection state. Requires an initialization message to + /// start. + SslInitializing(Credentials, ConnectionSslRequirement), + /// SSL upgrade message was sent, awaiting server response. + SslWaiting(Credentials, ConnectionSslRequirement), + /// SSL upgrade in progress, waiting for handshake to complete. + SslConnecting(Credentials), + /// Uninitialized connection state. Requires an initialization message to + /// start. + Initializing(Credentials), + /// The initial connection string has been sent and we are waiting for an + /// auth response. + Connecting(Credentials), + /// The server has requested SCRAM auth. This holds a sub-state-machine that + /// manages a SCRAM challenge. + Scram(ClientTransaction, ClientEnvironmentImpl), + /// The authentication is successful and we are synchronizing server + /// parameters. + Connected, + /// The server is ready for queries. + Ready, + /// The connection failed. 
+ Error, +} + +#[derive(Clone, Copy, Debug)] +pub enum ConnectionStateType { + Connecting, + SslConnecting, + Authenticating, + Synchronizing, + Ready, +} + +#[derive(Clone, Copy, PartialEq, Eq, Debug, Default)] +pub enum Authentication { + #[default] + None, + Password, + Md5, + ScramSha256, +} + +#[derive(Debug)] +pub enum ConnectionDrive<'a> { + Initial, + Message(Message<'a>), + SslResponse(SSLResponse<'a>), + SslReady, +} + +impl<'a> ConnectionDrive<'a> { + pub fn message(&self) -> Result<&Message<'a>, ConnectionError> { + match self { + ConnectionDrive::Message(msg) => Ok(msg), + _ => Err(invalid_state!( + "Expected Message variant, but got a different ConnectionDrive variant" + )), + } + } +} + +pub trait ConnectionStateSend { + fn send_initial(&mut self, message: InitialBuilder) -> Result<(), std::io::Error>; + fn send(&mut self, message: FrontendBuilder) -> Result<(), std::io::Error>; + fn upgrade(&mut self) -> Result<(), std::io::Error>; +} + +/// A callback for connection state changes. 
+#[allow(unused)] +pub trait ConnectionStateUpdate: ConnectionStateSend { + fn parameter(&mut self, name: &str, value: &str) {} + fn cancellation_key(&mut self, pid: i32, key: i32) {} + fn state_changed(&mut self, state: ConnectionStateType) {} + fn server_error(&mut self, error: &ErrorResponse) {} + fn auth(&mut self, auth: Authentication) {} +} + +#[derive(Clone, Copy, PartialEq, Eq, Default, Debug)] +pub enum ConnectionSslRequirement { + #[default] + Disable, + Optional, + Required, +} + +/// ASCII state diagram for the connection state machine +/// +/// ```mermaid +/// stateDiagram-v2 +/// [*] --> SslInitializing: SSL not disabled +/// [*] --> Initializing: SSL disabled +/// SslInitializing --> SslWaiting: Send SSL request +/// SslWaiting --> SslConnecting: SSL accepted +/// SslWaiting --> Connecting: SSL rejected (if not required) +/// SslConnecting --> Connecting: SSL handshake complete +/// Initializing --> Connecting: Send startup message +/// Connecting --> Connected: Authentication successful +/// Connecting --> Scram: SCRAM auth requested +/// Scram --> Connected: SCRAM auth successful +/// Connected --> Ready: Parameter sync complete +/// Ready --> [*]: Connection closed +/// state Error { +/// [*] --> [*]: Any state can transition to Error +/// } +/// ``` +/// +/// The state machine for a Postgres connection. The state machine is driven +/// with calls to [`Self::drive`]. 
+pub struct ConnectionState(ConnectionStateImpl); + +impl ConnectionState { + pub fn new(credentials: Credentials, ssl_mode: ConnectionSslRequirement) -> Self { + if ssl_mode == ConnectionSslRequirement::Disable { + Self(ConnectionStateImpl::Initializing(credentials)) + } else { + Self(ConnectionStateImpl::SslInitializing(credentials, ssl_mode)) + } + } + + pub fn is_ready(&self) -> bool { + matches!(self.0, ConnectionStateImpl::Ready) + } + + pub fn read_ssl_response(&self) -> bool { + matches!(self.0, ConnectionStateImpl::SslWaiting(..)) + } + + pub fn drive( + &mut self, + drive: ConnectionDrive, + update: &mut impl ConnectionStateUpdate, + ) -> Result<(), ConnectionError> { + use ConnectionStateImpl::*; + let state = &mut self.0; + trace!("Received drive {drive:?} in state {state:?}"); + match state { + SslInitializing(credentials, mode) => { + if !matches!(drive, ConnectionDrive::Initial) { + return Err(invalid_state!( + "Expected Initial drive for SslInitializing state" + )); + } + update.send_initial(InitialBuilder::SSLRequest(builder::SSLRequest::default()))?; + *state = SslWaiting(std::mem::take(credentials), *mode); + update.state_changed(ConnectionStateType::Connecting); + } + SslWaiting(credentials, mode) => { + let ConnectionDrive::SslResponse(response) = drive else { + return Err(invalid_state!( + "Expected SslResponse drive for SslWaiting state" + )); + }; + + if *mode == ConnectionSslRequirement::Disable { + // Should not be possible + return Err(invalid_state!("SSL mode is Disable in SslWaiting state")); + } + + if response.code() == b'S' { + // Accepted + update.upgrade()?; + *state = SslConnecting(std::mem::take(credentials)); + update.state_changed(ConnectionStateType::SslConnecting); + } else if response.code() == b'N' { + // Rejected + if *mode == ConnectionSslRequirement::Required { + return Err(ConnectionError::SslError(SslError::SslRequiredByClient)); + } + Self::send_startup_message(credentials, update)?; + *state = 
Connecting(std::mem::take(credentials)); + } else { + return Err(ConnectionError::UnexpectedServerResponse(format!( + "Unexpected SSL response from server: {:?}", + response.code() as char + ))); + } + } + SslConnecting(credentials) => { + let ConnectionDrive::SslReady = drive else { + return Err(invalid_state!( + "Expected SslReady drive for SslConnecting state" + )); + }; + Self::send_startup_message(credentials, update)?; + *state = Connecting(std::mem::take(credentials)); + } + Initializing(credentials) => { + if !matches!(drive, ConnectionDrive::Initial) { + return Err(invalid_state!( + "Expected Initial drive for Initializing state" + )); + } + Self::send_startup_message(credentials, update)?; + *state = Connecting(std::mem::take(credentials)); + update.state_changed(ConnectionStateType::Connecting); + } + Connecting(credentials) => { + match_message!(drive.message()?, Backend { + (AuthenticationOk) => { + trace!("auth ok"); + *state = Connected; + update.state_changed(ConnectionStateType::Synchronizing); + }, + (AuthenticationSASL as sasl) => { + let mut found_scram_sha256 = false; + for mech in sasl.mechanisms() { + trace!("auth sasl: {:?}", mech); + if mech == "SCRAM-SHA-256" { + found_scram_sha256 = true; + break; + } + } + if !found_scram_sha256 { + return Err(ConnectionError::UnexpectedServerResponse("Server requested SASL authentication but does not support SCRAM-SHA-256".into())); + } + let credentials = credentials.clone(); + let mut tx = ClientTransaction::new("".into()); + let env = ClientEnvironmentImpl { credentials }; + let Some(initial_message) = tx.process_message(&[], &env)? 
else { + return Err(auth::SCRAMError::ProtocolError.into()); + }; + update.auth(Authentication::ScramSha256); + update.send(builder::SASLInitialResponse { + mechanism: "SCRAM-SHA-256", + response: &initial_message, + }.into())?; + *state = Scram(tx, env); + update.state_changed(ConnectionStateType::Authenticating); + }, + (AuthenticationMD5Password as md5) => { + trace!("auth md5"); + let md5_hash = auth::md5_password(&credentials.password, &credentials.username, &md5.salt()); + update.auth(Authentication::Md5); + update.send(builder::PasswordMessage { + password: &md5_hash, + }.into())?; + }, + (AuthenticationCleartextPassword) => { + trace!("auth cleartext"); + update.auth(Authentication::Password); + update.send(builder::PasswordMessage { + password: &credentials.password, + }.into())?; + }, + (ErrorResponse as error) => { + *state = Error; + update.server_error(&error); + return Err(error_to_server_error(error)); + }, + message => { + log_unknown_message(message, "Connecting") + }, + }); + } + Scram(tx, env) => { + match_message!(drive.message()?, Backend { + (AuthenticationSASLContinue as sasl) => { + let Some(message) = tx.process_message(&sasl.data(), env)? else { + return Err(auth::SCRAMError::ProtocolError.into()); + }; + update.send(builder::SASLResponse { + response: &message, + }.into())?; + }, + (AuthenticationSASLFinal as sasl) => { + let None = tx.process_message(&sasl.data(), env)? 
else { + return Err(auth::SCRAMError::ProtocolError.into()); + }; + }, + (AuthenticationOk) => { + trace!("auth ok"); + *state = Connected; + update.state_changed(ConnectionStateType::Synchronizing); + }, + (AuthenticationMessage as auth) => { + trace!("SCRAM Unknown auth message: {}", auth.status()) + }, + (ErrorResponse as error) => { + *state = Error; + update.server_error(&error); + return Err(error_to_server_error(error)); + }, + message => { + log_unknown_message(message, "SCRAM") + }, + }); + } + Connected => { + match_message!(drive.message()?, Backend { + (ParameterStatus as param) => { + trace!("param: {:?}={:?}", param.name(), param.value()); + update.parameter(param.name().try_into()?, param.value().try_into()?); + }, + (BackendKeyData as key_data) => { + trace!("key={:?} pid={:?}", key_data.key(), key_data.pid()); + update.cancellation_key(key_data.pid(), key_data.key()); + }, + (ReadyForQuery as ready) => { + trace!("ready: {:?}", ready.status() as char); + trace!("-> Ready"); + *state = Ready; + update.state_changed(ConnectionStateType::Ready); + }, + (ErrorResponse as error) => { + *state = Error; + update.server_error(&error); + return Err(error_to_server_error(error)); + }, + message => { + log_unknown_message(message, "Connected") + }, + }); + } + Ready | Error => { + return Err(invalid_state!("Unexpected drive for Ready or Error state")) + } + } + Ok(()) + } + + fn send_startup_message( + credentials: &Credentials, + update: &mut impl ConnectionStateUpdate, + ) -> Result<(), std::io::Error> { + let mut params = vec![ + builder::StartupNameValue { + name: "user", + value: &credentials.username, + }, + builder::StartupNameValue { + name: "database", + value: &credentials.database, + }, + ]; + for (name, value) in &credentials.server_settings { + params.push(builder::StartupNameValue { name, value }) + } + + update.send_initial(InitialBuilder::StartupMessage(builder::StartupMessage { + params: ¶ms, + })) + } +} + +fn log_unknown_message(message: 
&Message, state: &str) { + warn!( + "Unexpected message {:?} (length {}) received in {} state", + message.mtype(), + message.mlen(), + state + ); +} + +fn error_to_server_error(error: ErrorResponse) -> ConnectionError { + let mut code = String::new(); + let mut message = String::new(); + let mut extra = HashMap::new(); + + for field in error.fields() { + let value = field.value().to_string_lossy().into_owned(); + match ServerErrorField::try_from(field.etype()) { + Ok(ServerErrorField::Code) => code = value, + Ok(ServerErrorField::Message) => message = value, + Ok(field_type) => { + extra.insert(field_type, value); + } + // FIX: typo in log message ("Unxpected" -> "Unexpected"). + Err(_) => warn!( + "Unexpected server error field: {:?} ({:?})", + field.etype() as char, + value + ), + } + } + + ConnectionError::ServerError { + code, + message, + extra, + } +} diff --git a/edb/server/pgrust/src/connection/stream.rs b/edb/server/pgrust/src/connection/stream.rs new file mode 100644 index 000000000000..6606a083c933 --- /dev/null +++ b/edb/server/pgrust/src/connection/stream.rs @@ -0,0 +1,195 @@ +use super::{invalid_state, ConnectionError, SslError}; +use std::pin::Pin; + +pub trait Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin {} +impl Stream for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin {} + +/// A trait for streams that can be upgraded to a secure connection. +/// +/// This trait is usually implemented by tuples that represent a connection that can be +/// upgraded from an insecure to a secure state, typically through SSL/TLS. +pub trait StreamWithUpgrade: Unpin { + type Base: Stream; + type Upgrade: Stream; + type Config: Unpin; + + /// Perform a secure upgrade operation and return the new, wrapped connection.
+ #[allow(async_fn_in_trait)] + async fn secure_upgrade(self) -> Result + where + Self: Sized; +} + +impl StreamWithUpgrade for (S, ()) { + type Base = S; + type Upgrade = S; + type Config = (); + + async fn secure_upgrade(self) -> Result + where + Self: Sized, + { + Err(ConnectionError::SslError(SslError::SslUnsupportedByClient)) + } +} + +pub struct UpgradableStream +where + (B, C): StreamWithUpgrade, +{ + inner: UpgradableStreamInner, +} + +impl From<(B, C)> for UpgradableStream +where + (B, C): StreamWithUpgrade, +{ + #[inline(always)] + fn from(value: (B, C)) -> Self { + Self::new(value.0, value.1) + } +} + +impl UpgradableStream +where + (B, C): StreamWithUpgrade, +{ + #[inline(always)] + pub fn new(base: B, config: C) -> Self { + UpgradableStream { + inner: UpgradableStreamInner::Base(base, config), + } + } + + pub async fn secure_upgrade(&mut self) -> Result<(), ConnectionError> + where + (B, C): StreamWithUpgrade, + { + match std::mem::replace(&mut self.inner, UpgradableStreamInner::Upgrading) { + UpgradableStreamInner::Base(base, config) => { + self.inner = + UpgradableStreamInner::Upgraded((base, config).secure_upgrade().await?); + Ok(()) + } + UpgradableStreamInner::Upgraded(..) 
=> Err(invalid_state!( + "Attempted to upgrade an already upgraded stream" + )), + UpgradableStreamInner::Upgrading => Err(invalid_state!( + "Attempted to upgrade a stream that is already in the process of upgrading" + )), + } + } +} + +impl tokio::io::AsyncRead for UpgradableStream +where + (B, C): StreamWithUpgrade, +{ + #[inline(always)] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + let inner = &mut self.get_mut().inner; + match inner { + UpgradableStreamInner::Base(base, _) => Pin::new(base).poll_read(cx, buf), + UpgradableStreamInner::Upgraded(upgraded) => Pin::new(upgraded).poll_read(cx, buf), + UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Cannot read while upgrading", + ))), + } + } +} + +impl tokio::io::AsyncWrite for UpgradableStream +where + (B, C): StreamWithUpgrade, +{ + #[inline(always)] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + let inner = &mut self.get_mut().inner; + match inner { + UpgradableStreamInner::Base(base, _) => Pin::new(base).poll_write(cx, buf), + UpgradableStreamInner::Upgraded(upgraded) => Pin::new(upgraded).poll_write(cx, buf), + UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Cannot write while upgrading", + ))), + } + } + + #[inline(always)] + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let inner = &mut self.get_mut().inner; + match inner { + UpgradableStreamInner::Base(base, _) => Pin::new(base).poll_flush(cx), + UpgradableStreamInner::Upgraded(upgraded) => Pin::new(upgraded).poll_flush(cx), + UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Cannot flush while upgrading", + ))), + } + } 
+ + #[inline(always)] + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let inner = &mut self.get_mut().inner; + match inner { + UpgradableStreamInner::Base(base, _) => Pin::new(base).poll_shutdown(cx), + UpgradableStreamInner::Upgraded(upgraded) => Pin::new(upgraded).poll_shutdown(cx), + UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Cannot shutdown while upgrading", + ))), + } + } + + #[inline(always)] + fn is_write_vectored(&self) -> bool { + match &self.inner { + UpgradableStreamInner::Base(base, _) => base.is_write_vectored(), + UpgradableStreamInner::Upgraded(upgraded) => upgraded.is_write_vectored(), + UpgradableStreamInner::Upgrading => false, + } + } + + #[inline(always)] + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> std::task::Poll> { + let inner = &mut self.get_mut().inner; + match inner { + UpgradableStreamInner::Base(base, _) => Pin::new(base).poll_write_vectored(cx, bufs), + UpgradableStreamInner::Upgraded(upgraded) => { + Pin::new(upgraded).poll_write_vectored(cx, bufs) + } + UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Cannot write vectored while upgrading", + ))), + } + } +} + +enum UpgradableStreamInner +where + (B, C): StreamWithUpgrade, +{ + Base(B, C), + Upgraded(<(B, C) as StreamWithUpgrade>::Upgrade), + Upgrading, +} diff --git a/edb/server/pgrust/src/connection/tokio.rs b/edb/server/pgrust/src/connection/tokio.rs new file mode 100644 index 000000000000..a820ce4b5788 --- /dev/null +++ b/edb/server/pgrust/src/connection/tokio.rs @@ -0,0 +1,136 @@ +//! This module provides functionality to connect to Tokio TCP and Unix sockets. 
+ +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; + +use derive_more::From; +use std::net::SocketAddr; +use std::path::{Path, PathBuf}; + +/// Represents a socket address for Tokio connections, supporting both TCP and Unix sockets. +#[derive(From, Debug)] +pub enum TokioSocketAddress { + /// TCP socket address + Tcp(SocketAddr), + /// Unix socket address (only available on Unix systems) + #[cfg(unix)] + Unix(PathBuf), +} + +impl TokioSocketAddress { + /// Creates a new TCP socket address + #[inline(always)] + pub fn new_tcp(addr: SocketAddr) -> Self { + TokioSocketAddress::Tcp(addr) + } + + /// Creates a new Unix socket address (only available on Unix systems) + #[inline(always)] + #[cfg(unix)] + pub fn new_unix>(path: P) -> Self { + TokioSocketAddress::Unix(path.as_ref().into()) + } +} + +impl TokioSocketAddress { + /// Connects to the socket address and returns a TokioStream + pub async fn connect(&self) -> std::io::Result { + match self { + TokioSocketAddress::Tcp(addr) => { + let stream = TcpStream::connect(addr).await?; + Ok(TokioStream::Tcp(stream)) + } + #[cfg(unix)] + TokioSocketAddress::Unix(path) => { + let stream = UnixStream::connect(path).await?; + Ok(TokioStream::Unix(stream)) + } + } + } +} + +/// Represents a connected Tokio stream, either TCP or Unix +pub enum TokioStream { + /// TCP stream + Tcp(TcpStream), + /// Unix stream (only available on Unix systems) + #[cfg(unix)] + Unix(UnixStream), +} + +impl AsyncRead for TokioStream { + #[inline(always)] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + TokioStream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf), + #[cfg(unix)] + TokioStream::Unix(stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TokioStream { + #[inline(always)] + fn poll_write( + self: 
Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + TokioStream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf), + #[cfg(unix)] + TokioStream::Unix(stream) => Pin::new(stream).poll_write(cx, buf), + } + } + + #[inline(always)] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + TokioStream::Tcp(stream) => Pin::new(stream).poll_flush(cx), + #[cfg(unix)] + TokioStream::Unix(stream) => Pin::new(stream).poll_flush(cx), + } + } + + #[inline(always)] + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.get_mut() { + TokioStream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx), + #[cfg(unix)] + TokioStream::Unix(stream) => Pin::new(stream).poll_shutdown(cx), + } + } + + #[inline(always)] + fn is_write_vectored(&self) -> bool { + match self { + TokioStream::Tcp(stream) => stream.is_write_vectored(), + #[cfg(unix)] + TokioStream::Unix(stream) => stream.is_write_vectored(), + } + } + + #[inline(always)] + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + TokioStream::Tcp(stream) => Pin::new(stream).poll_write_vectored(cx, bufs), + #[cfg(unix)] + TokioStream::Unix(stream) => Pin::new(stream).poll_write_vectored(cx, bufs), + } + } +} diff --git a/edb/server/pgrust/src/lib.rs b/edb/server/pgrust/src/lib.rs index aadc8004ba43..9209e8feb014 100644 --- a/edb/server/pgrust/src/lib.rs +++ b/edb/server/pgrust/src/lib.rs @@ -1,6 +1,6 @@ -mod conn_string; - -pub use conn_string::{parse_postgres_url, Host, ParseError}; +pub mod auth; +pub mod connection; +pub mod protocol; #[cfg(feature = "python_extension")] -mod python; +pub mod python; diff --git a/edb/server/pgrust/src/protocol/arrays.rs b/edb/server/pgrust/src/protocol/arrays.rs new file mode 100644 index 000000000000..91cb34b64f91 --- /dev/null +++ b/edb/server/pgrust/src/protocol/arrays.rs @@ -0,0 
+1,346 @@ +#![allow(private_bounds)] +use super::{Enliven, FieldAccessArray, FixedSize, Meta, MetaRelation}; +use std::fmt::Write; +pub use std::marker::PhantomData; + +pub mod meta { + pub use super::ArrayMeta as Array; + pub use super::ZTArrayMeta as ZTArray; +} + +/// Inflated version of a zero-terminated array with zero-copy iterator access. +pub struct ZTArray<'a, T: FieldAccessArray> { + _phantom: PhantomData, + buf: &'a [u8], +} + +/// Metaclass for [`ZTArray`]. +pub struct ZTArrayMeta { + pub(crate) _phantom: PhantomData, +} +impl Meta for ZTArrayMeta { + fn name(&self) -> &'static str { + "ZTArray" + } + fn relations(&self) -> &'static [(MetaRelation, &'static dyn Meta)] { + &[(MetaRelation::Item, ::META)] + } +} + +impl Enliven for ZTArrayMeta +where + T: FieldAccessArray, +{ + type WithLifetime<'a> = ZTArray<'a, T>; + type ForMeasure<'a> = &'a [::ForMeasure<'a>]; + type ForBuilder<'a> = &'a [::ForBuilder<'a>]; +} + +impl<'a, T: FieldAccessArray> ZTArray<'a, T> { + pub const fn new(buf: &'a [u8]) -> Self { + Self { + buf, + _phantom: PhantomData, + } + } +} + +/// [`ZTArray`] [`Iterator`] for values of type `T`. 
+pub struct ZTArrayIter<'a, T: FieldAccessArray> { + _phantom: PhantomData, + buf: &'a [u8], +} + +impl<'a, T> std::fmt::Debug for ZTArray<'a, T> +where + T: FieldAccessArray, + ::WithLifetime<'a>: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_char('[')?; + for it in self { + it.fmt(f)?; + f.write_str(", ")?; + } + f.write_char(']')?; + Ok(()) + } +} + +impl<'a, T: FieldAccessArray> IntoIterator for ZTArray<'a, T> { + type Item = ::WithLifetime<'a>; + type IntoIter = ZTArrayIter<'a, T>; + fn into_iter(self) -> Self::IntoIter { + ZTArrayIter { + _phantom: PhantomData, + buf: self.buf, + } + } +} + +impl<'a, T: FieldAccessArray> IntoIterator for &ZTArray<'a, T> { + type Item = ::WithLifetime<'a>; + type IntoIter = ZTArrayIter<'a, T>; + fn into_iter(self) -> Self::IntoIter { + ZTArrayIter { + _phantom: PhantomData, + buf: self.buf, + } + } +} + +impl<'a, T: FieldAccessArray> Iterator for ZTArrayIter<'a, T> { + type Item = ::WithLifetime<'a>; + fn next(&mut self) -> Option { + if self.buf[0] == 0 { + return None; + } + let (value, buf) = self.buf.split_at(T::size_of_field_at(self.buf)); + self.buf = buf; + Some(T::extract(value)) + } +} + +/// Inflated version of a length-specified array with zero-copy iterator access. +pub struct Array<'a, L, T: FieldAccessArray> { + _phantom: PhantomData<(L, T)>, + buf: &'a [u8], + len: u32, +} + +/// Metaclass for [`Array`]. 
+pub struct ArrayMeta { + pub(crate) _phantom: PhantomData<(L, T)>, +} + +impl Meta for ArrayMeta { + fn name(&self) -> &'static str { + "Array" + } + fn relations(&self) -> &'static [(MetaRelation, &'static dyn Meta)] { + &[ + (MetaRelation::Length, L::META), + (MetaRelation::Item, T::META), + ] + } +} + +impl Enliven for ArrayMeta +where + T: FieldAccessArray, +{ + type WithLifetime<'a> = Array<'a, L, T>; + type ForMeasure<'a> = &'a [::ForMeasure<'a>]; + type ForBuilder<'a> = &'a [::ForBuilder<'a>]; +} + +impl<'a, L, T: FieldAccessArray> Array<'a, L, T> { + pub const fn new(buf: &'a [u8], len: u32) -> Self { + Self { + buf, + _phantom: PhantomData, + len, + } + } + + #[inline(always)] + pub const fn len(&self) -> usize { + self.len as usize + } + + #[inline(always)] + pub const fn is_empty(&self) -> bool { + self.len == 0 + } +} + +impl<'a, L, T> std::fmt::Debug for Array<'a, L, T> +where + T: FieldAccessArray, + ::WithLifetime<'a>: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_char('[')?; + for it in self { + it.fmt(f)?; + f.write_str(", ")?; + } + f.write_char(']')?; + Ok(()) + } +} + +/// [`Array`] [`Iterator`] for values of type `T`. 
+pub struct ArrayIter<'a, T: FieldAccessArray> { + _phantom: PhantomData, + buf: &'a [u8], + len: u32, +} + +impl<'a, L, T: FieldAccessArray> IntoIterator for Array<'a, L, T> { + type Item = ::WithLifetime<'a>; + type IntoIter = ArrayIter<'a, T>; + fn into_iter(self) -> Self::IntoIter { + ArrayIter { + _phantom: PhantomData, + buf: self.buf, + len: self.len, + } + } +} + +impl<'a, L, T: FieldAccessArray> IntoIterator for &Array<'a, L, T> { + type Item = ::WithLifetime<'a>; + type IntoIter = ArrayIter<'a, T>; + fn into_iter(self) -> Self::IntoIter { + ArrayIter { + _phantom: PhantomData, + buf: self.buf, + len: self.len, + } + } +} + +impl<'a, T: FieldAccessArray> Iterator for ArrayIter<'a, T> { + type Item = ::WithLifetime<'a>; + fn next(&mut self) -> Option { + if self.len == 0 { + return None; + } + self.len -= 1; + let len = T::size_of_field_at(self.buf); + let (value, buf) = self.buf.split_at(len); + self.buf = buf; + Some(T::extract(value)) + } +} + +/// Define array accesses for inflated, strongly-typed arrays of both +/// zero-terminated and length-delimited types. +macro_rules! 
array_access { + ($ty:ty) => { + $crate::protocol::arrays::array_access!($ty | u8 i16 i32); + }; + ($ty:ty | $($len:ty)*) => { + $( + #[allow(unused)] + impl $crate::protocol::FieldAccess<$crate::protocol::meta::Array<$len, $ty>> { + pub const fn meta() -> &'static dyn $crate::protocol::Meta { + &$crate::protocol::meta::Array::<$len, $ty> { _phantom: std::marker::PhantomData } + } + #[inline] + pub const fn size_of_field_at(mut buf: &[u8]) -> usize { + let mut size = std::mem::size_of::<$len>(); + let mut len = $crate::protocol::FieldAccess::<$len>::extract(buf); + buf = buf.split_at(size).1; + loop { + if len == 0 { + break; + } + len -= 1; + let elem_size = $crate::protocol::FieldAccess::<$ty>::size_of_field_at(buf); + buf = buf.split_at(elem_size).1; + size += elem_size; + } + size + } + #[inline(always)] + pub const fn extract(buf: &[u8]) -> $crate::protocol::Array<'_, $len, $ty> { + let len = $crate::protocol::FieldAccess::<$len>::extract(buf); + $crate::protocol::Array::new(buf.split_at(std::mem::size_of::<$len>()).1, len as u32) + } + #[inline] + pub const fn measure<'a>(buffer: &'a[<$ty as $crate::protocol::Enliven>::ForMeasure<'a>]) -> usize { + let mut size = std::mem::size_of::<$len>(); + let mut index = 0; + loop { + if index + 1 > buffer.len() { + break; + } + let item = &buffer[index]; + size += $crate::protocol::FieldAccess::<$ty>::measure(item); + index += 1; + } + size + } + #[inline(always)] + pub fn copy_to_buf<'a>(buf: &mut $crate::protocol::writer::BufWriter, value: &'a[<$ty as $crate::protocol::Enliven>::ForBuilder<'a>]) { + buf.write(&<$len>::to_be_bytes(value.len() as _)); + for elem in value { + $crate::protocol::FieldAccess::<$ty>::copy_to_buf_ref(buf, elem); + } + } + + } + )* + + #[allow(unused)] + impl $crate::protocol::FieldAccess<$crate::protocol::meta::ZTArray<$ty>> { + pub const fn meta() -> &'static dyn $crate::protocol::Meta { + &$crate::protocol::meta::ZTArray::<$ty> { _phantom: std::marker::PhantomData } + } + #[inline] + pub 
const fn size_of_field_at(mut buf: &[u8]) -> usize { + let mut size = 1; + loop { + if buf[0] == 0 { + return size; + } + let elem_size = $crate::protocol::FieldAccess::<$ty>::size_of_field_at(buf); + buf = buf.split_at(elem_size).1; + size += elem_size; + } + } + #[inline(always)] + pub const fn extract(mut buf: &[u8]) -> $crate::protocol::ZTArray<$ty> { + $crate::protocol::ZTArray::new(buf) + } + #[inline] + pub const fn measure<'a>(buffer: &'a[<$ty as $crate::protocol::Enliven>::ForMeasure<'a>]) -> usize { + let mut size = 1; + let mut index = 0; + loop { + if index + 1 > buffer.len() { + break; + } + let item = &buffer[index]; + size += $crate::protocol::FieldAccess::<$ty>::measure(item); + index += 1; + } + size + } + #[inline(always)] + pub fn copy_to_buf(buf: &mut $crate::protocol::writer::BufWriter, value: &[<$ty as $crate::protocol::Enliven>::ForBuilder<'_>]) { + for elem in value { + $crate::protocol::FieldAccess::<$ty>::copy_to_buf_ref(buf, elem); + } + buf.write_u8(0); + } + } + }; +} +pub(crate) use array_access; + +// Arrays of type [`u8`] are special-cased to return a slice of bytes. +impl AsRef<[u8]> for Array<'_, T, u8> { + fn as_ref(&self) -> &[u8] { + &self.buf[..self.len as _] + } +} + +// Arrays of fixed-size elements can extract elements in O(1). 
+impl<'a, L: TryInto, T: FixedSize + FieldAccessArray> Array<'a, L, T> { + #[allow(unused)] + fn get(&self, index: L) -> Option<::WithLifetime<'a>> { + let Ok(index) = index.try_into() else { + return None; + }; + let index: usize = index; + if index >= self.len as _ { + None + } else { + let segment = &self.buf[T::SIZE * index..T::SIZE * (index + 1)]; + Some(T::extract(segment)) + } + } +} diff --git a/edb/server/pgrust/src/protocol/buffer.rs b/edb/server/pgrust/src/protocol/buffer.rs new file mode 100644 index 000000000000..292b48997859 --- /dev/null +++ b/edb/server/pgrust/src/protocol/buffer.rs @@ -0,0 +1,224 @@ +use std::{collections::VecDeque, marker::PhantomData}; + +use super::StructLength; + +/// A buffer that accumulates bytes of sized structs and feeds them to provided sink function when messages +/// are complete. This buffer handles partial messages and multiple messages in a single push. +#[derive(Default)] +pub struct StructBuffer { + _phantom: PhantomData, + accum: VecDeque, +} + +impl StructBuffer { + /// Pushes bytes into the buffer, potentially feeding output to the function. + /// + /// # Lifetimes + /// - `'a`: The lifetime of the input byte slice. + /// - `'b`: The lifetime of the mutable reference to `self`. + /// - `'c`: A lifetime used in the closure's type, representing the lifetime of the `M::Struct` instances passed to it. + /// + /// The constraint `'a: 'b` ensures that the input bytes live at least as long as the mutable reference to `self`. + /// + /// The `for<'c>` syntax in the closure type is a higher-ranked trait bound. It indicates that the closure + /// must be able to handle `M::Struct` with any lifetime `'c`. This is crucial because: + /// + /// 1. It allows the `push` method to create `M::Struct` instances with lifetimes that are not known + /// at the time the closure is defined. + /// 2. 
It ensures that the `M::Struct` instances passed to the closure are only valid for the duration + /// of each call to the closure, not for the entire lifetime of the `push` method. + /// 3. It prevents the closure from storing or returning these `M::Struct` instances, as their lifetime + /// is limited to the scope of each closure invocation. + pub fn push<'a: 'b, 'b>( + &'b mut self, + bytes: &'a [u8], + mut f: impl for<'c> FnMut(M::Struct<'c>), + ) { + if self.accum.is_empty() { + // Fast path: try to process the input directly + let mut offset = 0; + while offset < bytes.len() { + if let Some(len) = M::length_of_buf(&bytes[offset..]) { + if offset + len <= bytes.len() { + f(M::new(&bytes[offset..offset + len])); + offset += len; + } else { + break; + } + } else { + break; + } + } + if offset == bytes.len() { + return; + } + self.accum.extend(&bytes[offset..]); + } else { + self.accum.extend(bytes); + } + + // Slow path: process accumulated data + let contiguous = self.accum.make_contiguous(); + let mut total_processed = 0; + while let Some(len) = M::length_of_buf(&contiguous[total_processed..]) { + if total_processed + len <= contiguous.len() { + let message_bytes = &contiguous[total_processed..total_processed + len]; + f(M::new(message_bytes)); + total_processed += len; + } else { + break; + } + } + if total_processed > 0 { + self.accum.rotate_left(total_processed); + self.accum.truncate(self.accum.len() - total_processed); + } + } + + /// Pushes bytes into the buffer, potentially feeding output to the function. + /// + /// # Lifetimes + /// - `'a`: The lifetime of the input byte slice. + /// - `'b`: The lifetime of the mutable reference to `self`. + /// - `'c`: A lifetime used in the closure's type, representing the lifetime of the `M::Struct` instances passed to it. + /// + /// The constraint `'a: 'b` ensures that the input bytes live at least as long as the mutable reference to `self`. 
+ /// + /// The `for<'c>` syntax in the closure type is a higher-ranked trait bound. It indicates that the closure + /// must be able to handle `M::Struct` with any lifetime `'c`. This is crucial because: + /// + /// 1. It allows the `push` method to create `M::Struct` instances with lifetimes that are not known + /// at the time the closure is defined. + /// 2. It ensures that the `M::Struct` instances passed to the closure are only valid for the duration + /// of each call to the closure, not for the entire lifetime of the `push` method. + /// 3. It prevents the closure from storing or returning these `M::Struct` instances, as their lifetime + /// is limited to the scope of each closure invocation. + pub fn push_fallible<'a: 'b, 'b, E>( + &'b mut self, + bytes: &'a [u8], + mut f: impl for<'c> FnMut(M::Struct<'c>) -> Result<(), E>, + ) -> Result<(), E> { + if self.accum.is_empty() { + // Fast path: try to process the input directly + let mut offset = 0; + while offset < bytes.len() { + if let Some(len) = M::length_of_buf(&bytes[offset..]) { + if offset + len <= bytes.len() { + f(M::new(&bytes[offset..offset + len]))?; + offset += len; + } else { + break; + } + } else { + break; + } + } + if offset == bytes.len() { + return Ok(()); + } + self.accum.extend(&bytes[offset..]); + } else { + self.accum.extend(bytes); + } + + // Slow path: process accumulated data + let contiguous = self.accum.make_contiguous(); + let mut total_processed = 0; + while let Some(len) = M::length_of_buf(&contiguous[total_processed..]) { + if total_processed + len <= contiguous.len() { + let message_bytes = &contiguous[total_processed..total_processed + len]; + f(M::new(message_bytes))?; + total_processed += len; + } else { + break; + } + } + if total_processed > 0 { + self.accum.rotate_left(total_processed); + self.accum.truncate(self.accum.len() - total_processed); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::StructBuffer; + use crate::protocol::{builder, meta, Encoded, 
Message}; + + /// Create a test data buffer containing three messages + fn test_data() -> (Vec, Vec) { + let mut test_data = vec![]; + let mut lengths = vec![]; + test_data.append(&mut builder::Sync::default().to_vec()); + let len = test_data.len(); + lengths.push(len); + test_data.append(&mut builder::CommandComplete { tag: "TAG" }.to_vec()); + lengths.push(test_data.len() - len); + let len = test_data.len(); + test_data.append( + &mut builder::DataRow { + values: &[Encoded::Value(b"1")], + } + .to_vec(), + ); + lengths.push(test_data.len() - len); + (test_data, lengths) + } + + fn process_chunks(buf: &[u8], chunk_lengths: &[usize]) { + assert_eq!( + chunk_lengths.iter().sum::(), + buf.len(), + "Sum of chunk lengths must equal total buffer length" + ); + + let mut accumulated_messages: Vec> = Vec::new(); + let mut buffer = StructBuffer::::default(); + let mut f = |msg: Message| { + eprintln!("Message: {msg:?}"); + accumulated_messages.push(msg.to_vec()); + }; + + let mut start = 0; + for &length in chunk_lengths { + let end = start + length; + let chunk = &buf[start..end]; + eprintln!("Chunk: {chunk:?}"); + + buffer.push(chunk, &mut f); + start = end; + } + + assert_eq!(accumulated_messages.len(), 3); + + let mut out = vec![]; + for message in accumulated_messages { + out.append(&mut message.to_vec()); + } + + assert_eq!(&out, buf); + } + + #[test] + fn test_message_buffer_chunked() { + let (test_data, chunk_lengths) = test_data(); + process_chunks(&test_data, &chunk_lengths); + } + + #[test] + fn test_message_buffer_byte_by_byte() { + let (test_data, _) = test_data(); + let chunk_lengths: Vec = vec![1; test_data.len()]; + process_chunks(&test_data, &chunk_lengths); + } + + #[test] + fn test_message_buffer_incremental_chunks() { + let (test_data, _) = test_data(); + for i in 0..test_data.len() { + let chunk_lengths = vec![i, test_data.len() - i]; + process_chunks(&test_data, &chunk_lengths); + } + } +} diff --git a/edb/server/pgrust/src/protocol/datatypes.rs 
b/edb/server/pgrust/src/protocol/datatypes.rs new file mode 100644 index 000000000000..bf5fe0f0addd --- /dev/null +++ b/edb/server/pgrust/src/protocol/datatypes.rs @@ -0,0 +1,508 @@ +use std::{marker::PhantomData, str::Utf8Error}; + +use super::{ + arrays::{array_access, Array, ArrayMeta}, + field_access, + writer::BufWriter, + Enliven, FieldAccess, Meta, +}; + +pub mod meta { + pub use super::EncodedMeta as Encoded; + pub use super::LengthMeta as Length; + pub use super::RestMeta as Rest; + pub use super::ZTStringMeta as ZTString; +} + +/// Represents the remainder of data in a message. +#[derive(Debug, PartialEq, Eq)] +pub struct Rest<'a> { + buf: &'a [u8], +} + +field_access!(RestMeta); + +pub struct RestMeta {} +impl Meta for RestMeta { + fn name(&self) -> &'static str { + "Rest" + } +} +impl Enliven for RestMeta { + type WithLifetime<'a> = Rest<'a>; + type ForMeasure<'a> = &'a [u8]; + type ForBuilder<'a> = &'a [u8]; +} + +impl<'a> Rest<'a> {} + +impl<'a> AsRef<[u8]> for Rest<'a> { + fn as_ref(&self) -> &[u8] { + self.buf + } +} + +impl<'a> std::ops::Deref for Rest<'a> { + type Target = [u8]; + fn deref(&self) -> &Self::Target { + self.buf + } +} + +impl PartialEq<[u8]> for Rest<'_> { + fn eq(&self, other: &[u8]) -> bool { + self.buf == other + } +} + +impl PartialEq<&[u8; N]> for Rest<'_> { + fn eq(&self, other: &&[u8; N]) -> bool { + self.buf == *other + } +} + +impl PartialEq<&[u8]> for Rest<'_> { + fn eq(&self, other: &&[u8]) -> bool { + self.buf == *other + } +} + +impl FieldAccess { + #[inline(always)] + pub const fn meta() -> &'static dyn Meta { + &RestMeta {} + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> usize { + buf.len() + } + #[inline(always)] + pub const fn extract(buf: &[u8]) -> Rest<'_> { + Rest { buf } + } + #[inline(always)] + pub const fn measure(buf: &[u8]) -> usize { + buf.len() + } + #[inline(always)] + pub fn copy_to_buf(buf: &mut BufWriter, value: &[u8]) { + buf.write(value) + } +} + +/// A zero-terminated string. 
+#[allow(unused)] +pub struct ZTString<'a> { + buf: &'a [u8], +} + +field_access!(ZTStringMeta); +array_access!(ZTStringMeta); + +pub struct ZTStringMeta {} +impl Meta for ZTStringMeta { + fn name(&self) -> &'static str { + "ZTString" + } +} + +impl Enliven for ZTStringMeta { + type WithLifetime<'a> = ZTString<'a>; + type ForMeasure<'a> = &'a str; + type ForBuilder<'a> = &'a str; +} + +impl std::fmt::Debug for ZTString<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + String::from_utf8_lossy(self.buf).fmt(f) + } +} + +impl<'a> ZTString<'a> { + pub fn to_owned(&self) -> Result { + String::from_utf8(self.buf.to_owned()) + } + + pub fn to_str(&self) -> Result<&str, std::str::Utf8Error> { + std::str::from_utf8(self.buf) + } + + pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { + String::from_utf8_lossy(self.buf) + } +} + +impl PartialEq for ZTString<'_> { + fn eq(&self, other: &Self) -> bool { + self.buf == other.buf + } +} +impl Eq for ZTString<'_> {} + +impl PartialEq for ZTString<'_> { + fn eq(&self, other: &str) -> bool { + self.buf == other.as_bytes() + } +} + +impl PartialEq<&str> for ZTString<'_> { + fn eq(&self, other: &&str) -> bool { + self.buf == other.as_bytes() + } +} + +impl<'a> TryInto<&'a str> for ZTString<'a> { + type Error = Utf8Error; + fn try_into(self) -> Result<&'a str, Self::Error> { + std::str::from_utf8(self.buf) + } +} + +impl FieldAccess { + #[inline(always)] + pub const fn meta() -> &'static dyn Meta { + &ZTStringMeta {} + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> usize { + let mut i = 0; + loop { + if buf[i] == 0 { + return i + 1; + } + i += 1; + } + } + #[inline(always)] + pub const fn extract(buf: &[u8]) -> ZTString<'_> { + let buf = buf.split_at(buf.len() - 1).0; + ZTString { buf } + } + #[inline(always)] + pub const fn measure(buf: &str) -> usize { + buf.len() + 1 + } + #[inline(always)] + pub fn copy_to_buf(buf: &mut BufWriter, value: &str) { + 
buf.write(value.as_bytes()); + buf.write_u8(0); + } + #[inline(always)] + pub fn copy_to_buf_ref(buf: &mut BufWriter, value: &str) { + buf.write(value.as_bytes()); + buf.write_u8(0); + } +} + +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] +/// An encoded row value. +pub enum Encoded<'a> { + #[default] + Null, + Value(&'a [u8]), +} + +impl<'a> AsRef> for Encoded<'a> { + fn as_ref(&self) -> &Encoded<'a> { + self + } +} + +field_access!(EncodedMeta); +array_access!(EncodedMeta); + +pub struct EncodedMeta {} +impl Meta for EncodedMeta { + fn name(&self) -> &'static str { + "Encoded" + } +} + +impl Enliven for EncodedMeta { + type WithLifetime<'a> = Encoded<'a>; + type ForMeasure<'a> = Encoded<'a>; + type ForBuilder<'a> = Encoded<'a>; +} + +impl<'a> Encoded<'a> {} + +impl PartialEq for Encoded<'_> { + fn eq(&self, other: &str) -> bool { + self == &Encoded::Value(other.as_bytes()) + } +} + +impl PartialEq<&str> for Encoded<'_> { + fn eq(&self, other: &&str) -> bool { + self == &Encoded::Value(other.as_bytes()) + } +} + +impl PartialEq<[u8]> for Encoded<'_> { + fn eq(&self, other: &[u8]) -> bool { + self == &Encoded::Value(other) + } +} + +impl PartialEq<&[u8]> for Encoded<'_> { + fn eq(&self, other: &&[u8]) -> bool { + self == &Encoded::Value(other) + } +} + +impl FieldAccess { + #[inline(always)] + pub const fn meta() -> &'static dyn Meta { + &EncodedMeta {} + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> usize { + const N: usize = std::mem::size_of::(); + if let Some(len) = buf.first_chunk::() { + let len = i32::from_be_bytes(*len); + if len == -1 { + N + } else { + len as usize + N + } + } else { + panic!() + } + } + #[inline(always)] + pub const fn extract(buf: &[u8]) -> Encoded<'_> { + const N: usize = std::mem::size_of::(); + if let Some((len, array)) = buf.split_first_chunk::() { + let len = i32::from_be_bytes(*len); + if len == -1 { + Encoded::Null + } else { + Encoded::Value(array) + } + } else { + panic!() + } + } + 
#[inline(always)] + pub const fn measure(value: &Encoded) -> usize { + match value { + Encoded::Null => std::mem::size_of::(), + Encoded::Value(value) => value.len() + std::mem::size_of::(), + } + } + #[inline(always)] + pub fn copy_to_buf(buf: &mut BufWriter, value: Encoded) { + Self::copy_to_buf_ref(buf, &value) + } + #[inline(always)] + pub fn copy_to_buf_ref(buf: &mut BufWriter, value: &Encoded) { + match value { + Encoded::Null => buf.write(&[0xff, 0xff, 0xff, 0xff]), + Encoded::Value(value) => { + let len: i32 = value.len() as _; + buf.write(&len.to_be_bytes()); + } + } + } +} + +// We alias usize here. Note that if this causes trouble in the future we can +// probably work around this by adding a new "const value" function to +// FieldAccess. For now it works! +pub struct LengthMeta(#[allow(unused)] i32); +impl Enliven for LengthMeta { + type WithLifetime<'a> = usize; + type ForMeasure<'a> = usize; + type ForBuilder<'a> = usize; +} +impl Meta for LengthMeta { + fn name(&self) -> &'static str { + "len" + } +} + +impl FieldAccess { + #[inline(always)] + pub const fn meta() -> &'static dyn Meta { + &LengthMeta(0) + } + #[inline(always)] + pub const fn constant(value: usize) -> LengthMeta { + LengthMeta(value as i32) + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> usize { + FieldAccess::::size_of_field_at(buf) + } + #[inline(always)] + pub const fn extract(buf: &[u8]) -> usize { + FieldAccess::::extract(buf) as _ + } + #[inline(always)] + pub fn copy_to_buf(buf: &mut BufWriter, value: usize) { + FieldAccess::::copy_to_buf(buf, value as i32) + } + #[inline(always)] + pub fn copy_to_buf_rewind(buf: &mut BufWriter, rewind: usize, value: usize) { + FieldAccess::::copy_to_buf_rewind(buf, rewind, value as i32) + } +} + +macro_rules! 
basic_types { + ($($ty:ty)*) => { + $( + field_access!{$ty} + + impl Enliven for $ty { + type WithLifetime<'a> = $ty; + type ForMeasure<'a> = $ty; + type ForBuilder<'a> = $ty; + } + + impl Enliven for [$ty; S] { + type WithLifetime<'a> = [$ty; S]; + type ForMeasure<'a> = [$ty; S]; + type ForBuilder<'a> = [$ty; S]; + } + + #[allow(unused)] + impl FieldAccess<$ty> { + #[inline(always)] + pub const fn meta() -> &'static dyn Meta { + struct Meta {} + impl $crate::protocol::Meta for Meta { + fn name(&self) -> &'static str { + stringify!($ty) + } + } + &Meta{} + } + #[inline(always)] + pub const fn constant(value: usize) -> $ty { + value as _ + } + #[inline(always)] + pub const fn size_of_field_at(_: &[u8]) -> usize { + std::mem::size_of::<$ty>() + } + #[inline(always)] + pub const fn extract(buf: &[u8]) -> $ty { + if let Some(bytes) = buf.first_chunk() { + <$ty>::from_be_bytes(*bytes) + } else { + panic!() + } + } + #[inline(always)] + pub fn copy_to_buf(buf: &mut BufWriter, value: $ty) { + buf.write(&<$ty>::to_be_bytes(value)); + } + #[inline(always)] + pub fn copy_to_buf_rewind(buf: &mut BufWriter, rewind: usize, value: $ty) { + buf.write_rewind(rewind, &<$ty>::to_be_bytes(value)); + } + } + + #[allow(unused)] + impl FieldAccess<[$ty; S]> { + #[inline(always)] + pub const fn meta() -> &'static dyn Meta { + struct Meta {} + impl $crate::protocol::Meta for Meta { + fn name(&self) -> &'static str { + // TODO: can we extract this constant? 
+                        concat!('[', stringify!($ty), "; S]")
+                    }
+                }
+                &Meta{}
+            }
+            #[inline(always)]
+            pub const fn size_of_field_at(_buf: &[u8]) -> usize {
+                std::mem::size_of::<$ty>() * S
+            }
+            #[inline(always)]
+            pub const fn extract(mut buf: &[u8]) -> [$ty; S] {
+                let mut out: [$ty; S] = [0; S];
+                let mut i = 0;
+                loop {
+                    if i == S {
+                        break;
+                    }
+                    (out[i], buf) = if let Some((bytes, rest)) = buf.split_first_chunk() {
+                        (<$ty>::from_be_bytes(*bytes), rest)
+                    } else {
+                        panic!()
+                    };
+                    i += 1;
+                }
+                out
+            }
+            #[inline(always)]
+            pub fn copy_to_buf(mut buf: &mut BufWriter, value: [$ty; S]) {
+                if !buf.test(std::mem::size_of::<$ty>() * S) {
+                    return;
+                }
+                for n in value {
+                    buf.write(&<$ty>::to_be_bytes(n));
+                }
+            }
+        }
+
+        impl $crate::protocol::FixedSize for $ty {
+            const SIZE: usize = std::mem::size_of::<$ty>();
+        }
+        impl<const S: usize> $crate::protocol::FixedSize for [$ty; S] {
+            const SIZE: usize = std::mem::size_of::<$ty>() * S;
+        }
+
+        basic_types!(: array<$ty> u8 i16 i32);
+    )*
+    };
+
+    (: array<$ty:ty> $($len:ty)*) => {
+        $(
+        #[allow(unused)]
+        impl FieldAccess<ArrayMeta<$len, $ty>> {
+            pub const fn meta() -> &'static dyn Meta {
+                &ArrayMeta::<$len, $ty> { _phantom: PhantomData }
+            }
+            #[inline(always)]
+            pub const fn size_of_field_at(buf: &[u8]) -> usize {
+                const N: usize = std::mem::size_of::<$ty>();
+                const L: usize = std::mem::size_of::<$len>();
+                if let Some(len) = buf.first_chunk::<L>() {
+                    (<$len>::from_be_bytes(*len) as usize * N + L)
+                } else {
+                    panic!()
+                }
+            }
+            #[inline(always)]
+            pub const fn extract(mut buf: &[u8]) -> Array<$len, $ty> {
+                const N: usize = std::mem::size_of::<$ty>();
+                const L: usize = std::mem::size_of::<$len>();
+                if let Some((len, array)) = buf.split_first_chunk::<L>() {
+                    Array::new(array, <$len>::from_be_bytes(*len) as u32)
+                } else {
+                    panic!()
+                }
+            }
+            #[inline(always)]
+            pub const fn measure(buffer: &[$ty]) -> usize {
+                buffer.len() * std::mem::size_of::<$ty>() + std::mem::size_of::<$len>()
+            }
+            #[inline(always)]
+            pub fn copy_to_buf(mut buf: &mut BufWriter, value:
&[$ty]) { + let size: usize = std::mem::size_of::<$ty>() * value.len() + std::mem::size_of::<$len>(); + if !buf.test(size) { + return; + } + buf.write(&<$len>::to_be_bytes(value.len() as _)); + for n in value { + buf.write(&<$ty>::to_be_bytes(*n)); + } + } + } + )* + } +} +basic_types!(u8 i16 i32); diff --git a/edb/server/pgrust/src/protocol/definition.rs b/edb/server/pgrust/src/protocol/definition.rs new file mode 100644 index 000000000000..bef36d5aa651 --- /dev/null +++ b/edb/server/pgrust/src/protocol/definition.rs @@ -0,0 +1,740 @@ +use super::gen::protocol; +use super::message_group::message_group; +use crate::protocol::meta::*; + +message_group!( + /// The `Backend` message group contains messages sent from the backend to the frontend. + Backend: Message = [ + AuthenticationOk, + AuthenticationKerberosV5, + AuthenticationCleartextPassword, + AuthenticationMD5Password, + AuthenticationGSS, + AuthenticationGSSContinue, + AuthenticationSSPI, + AuthenticationSASL, + AuthenticationSASLContinue, + AuthenticationSASLFinal, + BackendKeyData, + BindComplete, + CloseComplete, + CommandComplete, + CopyData, + CopyDone, + CopyInResponse, + CopyOutResponse, + CopyBothResponse, + DataRow, + EmptyQueryResponse, + ErrorResponse, + FunctionCallResponse, + NegotiateProtocolVersion, + NoData, + NoticeResponse, + NotificationResponse, + ParameterDescription, + ParameterStatus, + ParseComplete, + PortalSuspended, + ReadyForQuery, + RowDescription + ] +); + +message_group!( + /// The `Frontend` message group contains messages sent from the frontend to the backend. + Frontend: Message = [ + Bind, + Close, + CopyData, + CopyDone, + CopyFail, + Describe, + Execute, + Flush, + FunctionCall, + GSSResponse, + Parse, + PasswordMessage, + Query, + SASLInitialResponse, + SASLResponse, + Sync, + Terminate + ] +); + +message_group!( + /// The `Initial` message group contains messages that are sent before the + /// normal message flow. 
+ Initial: InitialMessage = [ + CancelRequest, + GSSENCRequest, + SSLRequest, + StartupMessage + ] +); + +protocol!( + +/// A generic base for all Postgres mtype/mlen-style messages. +struct Message { + /// Identifies the message. + mtype: u8, + /// Length of message contents in bytes, including self. + mlen: len, + /// Message contents. + data: Rest, +} + +/// A generic base for all initial Postgres messages. +struct InitialMessage { + /// Length of message contents in bytes, including self. + mlen: len, + /// The identifier for this initial message. + protocol_version: i32, + /// Message contents. + data: Rest +} + +/// The `AuthenticationMessage` struct is a base for all Postgres authentication messages. +struct AuthenticationMessage: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len, + /// Specifies that the authentication was successful. + status: i32, +} + +/// The `AuthenticationOk` struct represents a message indicating successful authentication. +struct AuthenticationOk: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len = 8, + /// Specifies that the authentication was successful. + status: i32 = 0, +} + +/// The `AuthenticationKerberosV5` struct represents a message indicating that Kerberos V5 authentication is required. +struct AuthenticationKerberosV5: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len = 8, + /// Specifies that Kerberos V5 authentication is required. + status: i32 = 2, +} + +/// The `AuthenticationCleartextPassword` struct represents a message indicating that a cleartext password is required for authentication. 
+struct AuthenticationCleartextPassword: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len = 8, + /// Specifies that a clear-text password is required. + status: i32 = 3, +} + +/// The `AuthenticationMD5Password` struct represents a message indicating that an MD5-encrypted password is required for authentication. +struct AuthenticationMD5Password: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len = 12, + /// Specifies that an MD5-encrypted password is required. + status: i32 = 5, + /// The salt to use when encrypting the password. + salt: [u8; 4], +} + +/// The `AuthenticationSCMCredential` struct represents a message indicating that an SCM credential is required for authentication. +struct AuthenticationSCMCredential: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len = 6, + /// Any data byte, which is ignored. + byte: u8 = 0, +} + +/// The `AuthenticationGSS` struct represents a message indicating that GSSAPI authentication is required. +struct AuthenticationGSS: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len = 8, + /// Specifies that GSSAPI authentication is required. + status: i32 = 7, +} + +/// The `AuthenticationGSSContinue` struct represents a message indicating the continuation of GSSAPI authentication. +struct AuthenticationGSSContinue: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len, + /// Specifies that this message contains GSSAPI or SSPI data. + status: i32 = 8, + /// GSSAPI or SSPI authentication data. 
+ data: Rest, +} + +/// The `AuthenticationSSPI` struct represents a message indicating that SSPI authentication is required. +struct AuthenticationSSPI: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len = 8, + /// Specifies that SSPI authentication is required. + status: i32 = 9, +} + +/// The `AuthenticationSASL` struct represents a message indicating that SASL authentication is required. +struct AuthenticationSASL: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len, + /// Specifies that SASL authentication is required. + status: i32 = 10, + /// List of SASL authentication mechanisms, terminated by a zero byte. + mechanisms: ZTArray, +} + +/// The `AuthenticationSASLContinue` struct represents a message containing a SASL challenge during the authentication process. +struct AuthenticationSASLContinue: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len, + /// Specifies that this message contains a SASL challenge. + status: i32 = 11, + /// SASL data, specific to the SASL mechanism being used. + data: Rest, +} + +/// The `AuthenticationSASLFinal` struct represents a message indicating the completion of SASL authentication. +struct AuthenticationSASLFinal: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len, + /// Specifies that SASL authentication has completed. + status: i32 = 12, + /// SASL outcome "additional data", specific to the SASL mechanism being used. + data: Rest, +} + +/// The `BackendKeyData` struct represents a message containing the process ID and secret key for this backend. 
+struct BackendKeyData: Message {
+    /// Identifies the message as cancellation key data.
+    mtype: u8 = 'K',
+    /// Length of message contents in bytes, including self.
+    mlen: len = 12,
+    /// The process ID of this backend.
+    pid: i32,
+    /// The secret key of this backend.
+    key: i32,
+}
+
+/// The `Bind` struct represents a message to bind a named portal to a prepared statement.
+struct Bind: Message {
+    /// Identifies the message as a Bind command.
+    mtype: u8 = 'B',
+    /// Length of message contents in bytes, including self.
+    mlen: len,
+    /// The name of the destination portal.
+    portal: ZTString,
+    /// The name of the source prepared statement.
+    statement: ZTString,
+    /// The parameter format codes.
+    format_codes: Array<i16, i16>,
+    /// Array of parameter values and their lengths.
+    values: Array<i16, Encoded>,
+    /// The result-column format codes.
+    result_format_codes: Array<i16, i16>,
+}
+
+/// The `BindComplete` struct represents a message indicating that a Bind operation was successful.
+struct BindComplete: Message {
+    /// Identifies the message as a Bind-complete indicator.
+    mtype: u8 = '2',
+    /// Length of message contents in bytes, including self.
+    mlen: len = 4,
+}
+
+/// The `CancelRequest` struct represents a message to request the cancellation of a query.
+struct CancelRequest: InitialMessage {
+    /// Length of message contents in bytes, including self.
+    mlen: len = 16,
+    /// The cancel request code.
+    code: i32 = 80877102,
+    /// The process ID of the target backend.
+    pid: i32,
+    /// The secret key for the target backend.
+    key: i32,
+}
+
+/// The `Close` struct represents a message to close a prepared statement or portal.
+struct Close: Message {
+    /// Identifies the message as a Close command.
+    mtype: u8 = 'C',
+    /// Length of message contents in bytes, including self.
+    mlen: len,
+    /// 'S' to close a prepared statement; 'P' to close a portal.
+    ctype: u8,
+    /// The name of the prepared statement or portal to close.
+ name: ZTString, +} + +/// The `CloseComplete` struct represents a message indicating that a Close operation was successful. +struct CloseComplete: Message { + /// Identifies the message as a Close-complete indicator. + mtype: u8 = '3', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `CommandComplete` struct represents a message indicating the successful completion of a command. +struct CommandComplete: Message { + /// Identifies the message as a command-completed response. + mtype: u8 = 'C', + /// Length of message contents in bytes, including self. + mlen: len, + /// The command tag. + tag: ZTString, +} + +/// The `CopyData` struct represents a message containing data for a copy operation. +struct CopyData: Message { + /// Identifies the message as COPY data. + mtype: u8 = 'd', + /// Length of message contents in bytes, including self. + mlen: len, + /// Data that forms part of a COPY data stream. + data: Rest, +} + +/// The `CopyDone` struct represents a message indicating that a copy operation is complete. +struct CopyDone: Message { + /// Identifies the message as a COPY-complete indicator. + mtype: u8 = 'c', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `CopyFail` struct represents a message indicating that a copy operation has failed. +struct CopyFail: Message { + /// Identifies the message as a COPY-failure indicator. + mtype: u8 = 'f', + /// Length of message contents in bytes, including self. + mlen: len, + /// An error message to report as the cause of failure. + error_msg: ZTString, +} + +/// The `CopyInResponse` struct represents a message indicating that the server is ready to receive data for a copy-in operation. +struct CopyInResponse: Message { + /// Identifies the message as a Start Copy In response. + mtype: u8 = 'G', + /// Length of message contents in bytes, including self. + mlen: len, + /// 0 for textual, 1 for binary. 
+ format: u8, + /// The format codes for each column. + format_codes: Array, +} + +/// The `CopyOutResponse` struct represents a message indicating that the server is ready to send data for a copy-out operation. +struct CopyOutResponse: Message { + /// Identifies the message as a Start Copy Out response. + mtype: u8 = 'H', + /// Length of message contents in bytes, including self. + mlen: len, + /// 0 for textual, 1 for binary. + format: u8, + /// The format codes for each column. + format_codes: Array, +} + +/// The `CopyBothResponse` is used only for Streaming Replication. +struct CopyBothResponse: Message { + /// Identifies the message as a Start Copy Both response. + mtype: u8 = 'W', + /// Length of message contents in bytes, including self. + mlen: len, + /// 0 for textual, 1 for binary. + format: u8, + /// The format codes for each column. + format_codes: Array, +} + +/// The `DataRow` struct represents a message containing a row of data. +struct DataRow: Message { + /// Identifies the message as a data row. + mtype: u8 = 'D', + /// Length of message contents in bytes, including self. + mlen: len, + /// Array of column values and their lengths. + values: Array, +} + +/// The `Describe` struct represents a message to describe a prepared statement or portal. +struct Describe: Message { + /// Identifies the message as a Describe command. + mtype: u8 = 'D', + /// Length of message contents in bytes, including self. + mlen: len, + /// 'S' to describe a prepared statement; 'P' to describe a portal. + dtype: u8, + /// The name of the prepared statement or portal. + name: ZTString, +} + +/// The `EmptyQueryResponse` struct represents a message indicating that an empty query string was recognized. +struct EmptyQueryResponse: Message { + /// Identifies the message as a response to an empty query String. + mtype: u8 = 'I', + /// Length of message contents in bytes, including self. 
+ mlen: len = 4, +} + +/// The `ErrorResponse` struct represents a message indicating that an error has occurred. +struct ErrorResponse: Message { + /// Identifies the message as an error. + mtype: u8 = 'E', + /// Length of message contents in bytes, including self. + mlen: len, + /// Array of error fields and their values. + fields: ZTArray, +} + +/// The `ErrorField` struct represents a single error message within an `ErrorResponse`. +struct ErrorField { + /// A code identifying the field type. + etype: u8, + /// The field value. + value: ZTString, +} + +/// The `Execute` struct represents a message to execute a prepared statement or portal. +struct Execute: Message { + /// Identifies the message as an Execute command. + mtype: u8 = 'E', + /// Length of message contents in bytes, including self. + mlen: len, + /// The name of the portal to execute. + portal: ZTString, + /// Maximum number of rows to return. + max_rows: i32, +} + +/// The `Flush` struct represents a message to flush the backend's output buffer. +struct Flush: Message { + /// Identifies the message as a Flush command. + mtype: u8 = 'H', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `FunctionCall` struct represents a message to call a function. +struct FunctionCall: Message { + /// Identifies the message as a function call. + mtype: u8 = 'F', + /// Length of message contents in bytes, including self. + mlen: len, + /// OID of the function to execute. + function_id: i32, + /// The parameter format codes. + format_codes: Array, + /// Array of args and their lengths. + args: Array, + /// The format code for the result. + result_format_code: i16, +} + +/// The `FunctionCallResponse` struct represents a message containing the result of a function call. +struct FunctionCallResponse: Message { + /// Identifies the message as a function-call response. + mtype: u8 = 'V', + /// Length of message contents in bytes, including self. 
+ mlen: len, + /// The function result value. + result: Encoded, +} + +/// The `GSSENCRequest` struct represents a message requesting GSSAPI encryption. +struct GSSENCRequest: InitialMessage { + /// Length of message contents in bytes, including self. + mlen: len = 8, + /// The GSSAPI Encryption request code. + gssenc_request_code: i32 = 80877104, +} + +/// The `GSSResponse` struct represents a message containing a GSSAPI or SSPI response. +struct GSSResponse: Message { + /// Identifies the message as a GSSAPI or SSPI response. + mtype: u8 = 'p', + /// Length of message contents in bytes, including self. + mlen: len, + /// GSSAPI or SSPI authentication data. + data: Rest, +} + +/// The `NegotiateProtocolVersion` struct represents a message requesting protocol version negotiation. +struct NegotiateProtocolVersion: Message { + /// Identifies the message as a protocol version negotiation request. + mtype: u8 = 'v', + /// Length of message contents in bytes, including self. + mlen: len, + /// Newest minor protocol version supported by the server. + minor_version: i32, + /// List of protocol options not recognized. + options: Array, +} + +/// The `NoData` struct represents a message indicating that there is no data to return. +struct NoData: Message { + /// Identifies the message as a No Data indicator. + mtype: u8 = 'n', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `NoticeResponse` struct represents a message containing a notice. +struct NoticeResponse: Message { + /// Identifies the message as a notice. + mtype: u8 = 'N', + /// Length of message contents in bytes, including self. + mlen: len, + /// Array of notice fields and their values. + fields: ZTArray, +} + +/// The `NoticeField` struct represents a single error message within an `NoticeResponse`. +struct NoticeField: Message { + /// A code identifying the field type. + ntype: u8, + /// The field value. 
+ value: ZTString, +} + +/// The `NotificationResponse` struct represents a message containing a notification from the backend. +struct NotificationResponse: Message { + /// Identifies the message as a notification. + mtype: u8 = 'A', + /// Length of message contents in bytes, including self. + mlen: len, + /// The process ID of the notifying backend. + pid: i32, + /// The name of the notification channel. + channel: ZTString, + /// The notification payload. + payload: ZTString, +} + +/// The `ParameterDescription` struct represents a message describing the parameters needed by a prepared statement. +struct ParameterDescription: Message { + /// Identifies the message as a parameter description. + mtype: u8 = 't', + /// Length of message contents in bytes, including self. + mlen: len, + /// OIDs of the parameter data types. + param_types: Array, +} + +/// The `ParameterStatus` struct represents a message containing the current status of a parameter. +struct ParameterStatus: Message { + /// Identifies the message as a runtime parameter status report. + mtype: u8 = 'S', + /// Length of message contents in bytes, including self. + mlen: len, + /// The name of the parameter. + name: ZTString, + /// The current value of the parameter. + value: ZTString, +} + +/// The `Parse` struct represents a message to parse a query string. +struct Parse: Message { + /// Identifies the message as a Parse command. + mtype: u8 = 'P', + /// Length of message contents in bytes, including self. + mlen: len, + /// The name of the destination prepared statement. + statement: ZTString, + /// The query String to be parsed. + query: ZTString, + /// OIDs of the parameter data types. + param_types: Array, +} + +/// The `ParseComplete` struct represents a message indicating that a Parse operation was successful. +struct ParseComplete: Message { + /// Identifies the message as a Parse-complete indicator. + mtype: u8 = '1', + /// Length of message contents in bytes, including self. 
+ mlen: len = 4, +} + +/// The `PasswordMessage` struct represents a message containing a password. +struct PasswordMessage: Message { + /// Identifies the message as a password response. + mtype: u8 = 'p', + /// Length of message contents in bytes, including self. + mlen: len, + /// The password (encrypted or plaintext, depending on context). + password: ZTString, +} + +/// The `PortalSuspended` struct represents a message indicating that a portal has been suspended. +struct PortalSuspended: Message { + /// Identifies the message as a portal-suspended indicator. + mtype: u8 = 's', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `Query` struct represents a message to execute a simple query. +struct Query: Message { + /// Identifies the message as a simple query command. + mtype: u8 = 'Q', + /// Length of message contents in bytes, including self. + mlen: len, + /// The query String to be executed. + query: ZTString, +} + +/// The `ReadyForQuery` struct represents a message indicating that the backend is ready for a new query. +struct ReadyForQuery: Message { + /// Identifies the message as a ready-for-query indicator. + mtype: u8 = 'Z', + /// Length of message contents in bytes, including self. + mlen: len = 5, + /// Current transaction status indicator. + status: u8, +} + +/// The `RowDescription` struct represents a message describing the rows that will be returned by a query. +struct RowDescription: Message { + /// Identifies the message as a row description. + mtype: u8 = 'T', + /// Length of message contents in bytes, including self. + mlen: len, + /// Array of field descriptions. + fields: Array, +} + +/// The `RowField` struct represents a row within the `RowDescription` message. 
+struct RowField { + /// The field name + name: ZTString, + /// The table ID (OID) of the table the column is from, or 0 if not a column reference + table_oid: i32, + /// The attribute number of the column, or 0 if not a column reference + column_attr_number: i16, + /// The object ID of the field's data type + data_type_oid: i32, + /// The data type size (negative if variable size) + data_type_size: i16, + /// The type modifier + type_modifier: i32, + /// The format code being used for the field (0 for text, 1 for binary) + format_code: i16, +} + +/// The `SASLInitialResponse` struct represents a message containing a SASL initial response. +struct SASLInitialResponse: Message { + /// Identifies the message as a SASL initial response. + mtype: u8 = 'p', + /// Length of message contents in bytes, including self. + mlen: len, + /// Name of the SASL authentication mechanism. + mechanism: ZTString, + /// SASL initial response data. + response: Array, +} + +/// The `SASLResponse` struct represents a message containing a SASL response. +struct SASLResponse: Message { + /// Identifies the message as a SASL response. + mtype: u8 = 'p', + /// Length of message contents in bytes, including self. + mlen: len, + /// SASL response data. + response: Rest, +} + +/// The `SSLRequest` struct represents a message requesting SSL encryption. +struct SSLRequest: InitialMessage { + /// Length of message contents in bytes, including self. + mlen: len = 8, + /// The SSL request code. + code: i32 = 80877103, +} + +struct SSLResponse { + /// Specifies if SSL was accepted or rejected. + code: u8, +} + +/// The `StartupMessage` struct represents a message to initiate a connection. +struct StartupMessage: InitialMessage { + /// Length of message contents in bytes, including self. + mlen: len, + /// The protocol version number. + protocol: i32 = 196608, + /// List of parameter name-value pairs, terminated by a zero byte. 
+ params: ZTArray, +} + +/// The `StartupMessage` struct represents a name/value pair within the `StartupMessage` message. +struct StartupNameValue { + /// The parameter name. + name: ZTString, + /// The parameter value. + value: ZTString, +} + +/// The `Sync` struct represents a message to synchronize the frontend and backend. +struct Sync: Message { + /// Identifies the message as a Sync command. + mtype: u8 = 'S', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `Terminate` struct represents a message to terminate a connection. +struct Terminate: Message { + /// Identifies the message as a Terminate command. + mtype: u8 = 'X', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} +); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_all() { + let message = meta::Message::default(); + let initial_message = meta::InitialMessage::default(); + + for meta in meta::ALL { + eprintln!("{meta:#?}"); + if **meta != message && **meta != initial_message { + if meta.field("mtype").is_some() && meta.field("mlen").is_some() { + // If a message has mtype and mlen, it should subclass Message + assert_eq!(*meta.parent().unwrap(), message); + } else if meta.field("mlen").is_some() { + // If a message has mlen only, it should subclass InitialMessage + assert_eq!(*meta.parent().unwrap(), initial_message); + } + } + } + } +} diff --git a/edb/server/pgrust/src/protocol/gen.rs b/edb/server/pgrust/src/protocol/gen.rs new file mode 100644 index 000000000000..d58171cbcb4b --- /dev/null +++ b/edb/server/pgrust/src/protocol/gen.rs @@ -0,0 +1,776 @@ +/// Performs a first-pass parse on a struct, filling out some additional +/// metadata that makes the jobs of further macro passes much simpler. +/// +/// This macro takes a `next` parameter which allows you to funnel the +/// structured data from the macro into the next macro. The complex parsing +/// happens in here using a "push-down automation" technique. 
+///
+/// The term "push-down automaton" here refers to how metadata and parsed
+/// information are "pushed down" through the macro’s recursive structure. Each
+/// level of the macro adds its own layer of processing and metadata
+/// accumulation, eventually leading to the final output.
+///
+/// The `struct_elaborate!` macro is a tool designed to perform an initial
+/// parsing pass on a Rust `struct`, enriching it with metadata to facilitate
+/// further macro processing. It begins by extracting and analyzing the fields
+/// of the `struct`, capturing associated metadata such as attributes and types.
+/// This macro takes a `next` parameter, which is another macro to be invoked
+/// after the current one completes its task, allowing for a seamless chaining
+/// of macros where each one builds upon the results of the previous.
+///
+/// The macro first classifies each field based on its type, distinguishing
+/// between fixed-size types (like `u8`, `i16`, and arrays) and variable-sized
+/// types. It also tracks whether a field has a default value, ensuring that
+/// this information is passed along. To handle repetitive or complex patterns,
+/// especially when dealing with type information, the macro utilizes the
+/// `paste!` macro for duplication and transformation.
+///
+/// As it processes each field, the macro recursively calls itself, accumulating
+/// metadata and updating the state. This recursive approach is structured into
+/// different stages, such as `__builder_type__`, `__builder_value__`, and
+/// `__finalize__`, each responsible for handling specific aspects of the
+/// parsing process.
+///
+/// Once all fields have been processed, the macro enters the final stage, where
+/// it reconstructs an enriched `struct`-like data blob using the accumulated
+/// metadata. It then passes this enriched `struct` to the `next` macro for
+/// further processing.
+macro_rules! struct_elaborate {
+    (
+        $next:ident $( ($($next_args:tt)*) )?
=> + $( #[ $sdoc:meta ] )* + struct $name:ident $(: $super:ident)? { + $( + $( #[ doc = $fdoc:literal ] )* $field:ident : + $ty:tt $(< $($generics:ident),+ >)? + $( = $value:literal)? + ),* + $(,)? + } + ) => { + // paste! is necessary here because it allows us to re-interpret a "ty" + // as an explicit type pattern below. + struct_elaborate!(__builder_type__ + // Pass down a "fixed offset" flag that indicates whether the + // current field is at a fixed offset. This gets reset to + // `no_fixed_offset` when we hit a variable-sized field. + fixed(fixed_offset (0)) + fields($( + [ + // Note that we double the type so we can re-use some output + // patterns in `__builder_type__` + type( $ty $(<$($generics),+>)? )( $ty $(<$($generics),+>)? ), + value($($value)?), + docs($($fdoc)*), + name($field), + ] + )*) + // Accumulator for field data. + accum() + // Save the original struct parts so we can build the remainder of + // the struct at the end. + original($next $( ($($next_args)*) )? => $(#[$sdoc])* struct $name $(: $super)? 
{})); + }; + + // End of push-down automation - jumps to `__finalize__` + (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields() accum($($faccum:tt)*) original($($original:tt)*)) => { + struct_elaborate!(__finalize__ accum($($faccum)*) original($($original)*)); + }; + + // Skip __builder_value__ for 'len' + (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(len)(len), value(), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($crate::protocol::meta::Length), size(fixed=fixed), value(auto=auto), $($rest)*] $($frest)*) $($srest)*); + }; + (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(len)(len), value($($value:tt)+), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($crate::protocol::meta::Length), size(fixed=fixed), value(value=($($value)*)), $($rest)*] $($frest)*) $($srest)*); + }; + // Pattern match on known fixed-sized types and mark them as `size(fixed=fixed)` + (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type([u8; 4])($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + }; + (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(u8)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+1)) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + }; + (__builder_type__ fixed($fixed:ident $fixed_expr:expr)fields([type(i16)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+2)) fields([type($ty), size(fixed=fixed), $($rest)*] 
$($frest)*) $($srest)*); + }; + (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(i32)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + }; + + // Fallback for other types - variable sized + (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type($ty:ty)($ty2:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder_value__ fixed($fixed=>no_fixed_offset $fixed_expr=>(0)) fields([type($ty), size(variable=variable), $($rest)*] $($frest)*) $($srest)*); + }; + + // Next, mark the presence or absence of a value + (__builder_value__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ + type($ty:ty), size($($size:tt)*), value(), $($rest:tt)* + ] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value(no_value=no_value), $($rest)*] $($frest)*) $($srest)*); + }; + (__builder_value__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ + type($ty:ty), size($($size:tt)*), value($($value:tt)+), $($rest:tt)* + ] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value(value=($($value)*)), $($rest)*] $($frest)*) $($srest)*); + }; + + // Next, handle missing docs + (__builder_docs__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ + type($ty:ty), size($($size:tt)*), value($($value:tt)*), docs(), name($field:ident), $($rest:tt)* + ] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value($($value)*), docs(concat!("`", 
stringify!($field), "` field.")), name($field), $($rest)*] $($frest)*) $($srest)*); + }; + (__builder_docs__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ + type($ty:ty), size($($size:tt)*), value($($value:tt)*), docs($($fdoc:literal)+), $($rest:tt)* + ] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value($($value)*), docs(concat!($($fdoc)+)), $($rest)*] $($frest)*) $($srest)*); + }; + + + // Push down the field to the accumulator + (__builder__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ + type($ty:ty), size($($size:tt)*), value($($value:tt)*), docs($fdoc:expr), name($field:ident), $($rest:tt)* + ] $($frest:tt)*) accum($($faccum:tt)*) original($($original:tt)*)) => { + struct_elaborate!(__builder_type__ fixed($fixed_new $fixed_expr_new) fields($($frest)*) accum( + $($faccum)* + { + name($field), + type($ty), + size($($size)*), + value($($value)*), + docs($fdoc), + fixed($fixed=$fixed, $fixed_expr), + }, + ) original($($original)*)); + }; + + // Write the final struct + (__finalize__ accum($($accum:tt)*) original($next:ident $( ($($next_args:tt)*) )?=> $( #[ $sdoc:meta ] )* struct $name:ident $(: $super:ident)? {})) => { + $next ! ( + $( $($next_args)* , )? + struct $name { + super($($super)?), + docs($($sdoc),*), + fields( + $($accum)* + ), + } + ); + } +} + +macro_rules! protocol { + ($( $( #[ $sdoc:meta ] )* struct $name:ident $(: $super:ident)? { $($struct:tt)+ } )+) => { + $( + paste::paste!( + pub(crate) mod [<$name:lower>] { + #[allow(unused_imports)] + use super::*; + use $crate::protocol::gen::*; + struct_elaborate!(protocol_builder(__struct__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); + struct_elaborate!(protocol_builder(__meta__) => $( #[ $sdoc ] )* struct $name $(: $super)? 
{ $($struct)+ } ); + struct_elaborate!(protocol_builder(__measure__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); + struct_elaborate!(protocol_builder(__builder__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); + } + ); + )+ + + pub mod data { + #![allow(unused_imports)] + $( + paste::paste!( + pub use super::[<$name:lower>]::$name; + ); + )+ + } + pub mod meta { + #![allow(unused_imports)] + $( + paste::paste!( + pub use super::[<$name:lower>]::[<$name Meta>] as $name; + ); + )+ + + /// A slice containing the metadata references for all structs in + /// this definition. + #[allow(unused)] + pub const ALL: &'static [&'static dyn $crate::protocol::Meta] = &[ + $( + &$name {} + ),* + ]; + } + pub mod builder { + #![allow(unused_imports)] + $( + paste::paste!( + pub use super::[<$name:lower>]::[<$name Builder>] as $name; + ); + )+ + } + pub mod measure { + #![allow(unused_imports)] + $( + paste::paste!( + pub use super::[<$name:lower>]::[<$name Measure>] as $name; + ); + )+ + } + }; +} + +macro_rules! r#if { + (__is_empty__ [] {$($true:tt)*} else {$($false:tt)*}) => { + $($true)* + }; + (__is_empty__ [$($x:tt)+] {$($true:tt)*} else {$($false:tt)*}) => { + $($false)* + }; + (__has__ [$($x:tt)+] {$($true:tt)*}) => { + $($true)* + }; + (__has__ [] {$($true:tt)*}) => { + }; +} + +macro_rules! protocol_builder { + (__struct__, struct $name:ident { + super($($super:ident)?), + docs($($sdoc:meta),*), + fields($({ + name($field:ident), + type($type:ty), + size($($size:tt)*), + value($(value = ($value:expr))? $(no_value = $no_value:ident)? $(auto = $auto:ident)?), + docs($fdoc:expr), + fixed($fixed:ident=$fixed2:ident, $fixed_expr:expr), + $($rest:tt)* + },)*), + }) => { + paste::paste!( + /// Our struct we are building. + type S<'a> = $name<'a>; + /// The meta-struct for the struct we are building. + type Meta = [<$name Meta>]; + /// The measurement struct (used for `measure`). 
+ type M<'a> = [<$name Measure>]<'a>; + /// The builder struct (used for `to_vec` and other build operations) + type B<'a> = [<$name Builder>]<'a>; + /// The fields ordinal enum. + type F<'a> = [<$name Fields>]; + + $( #[$sdoc] )? + #[doc = concat!("\n\nAvailable fields: \n\n" $( + , " - [`", stringify!($field), "`](Self::", stringify!($field), "()): ", $fdoc, + $( " (value = `", stringify!($value), "`)", )? + "\n\n" + )* )] + pub struct $name<'a> { + /// Our zero-copy buffer. + #[doc(hidden)] + pub(crate) __buf: &'a [u8], + /// The calculated field offsets. + #[doc(hidden)] + __field_offsets: [usize; Meta::FIELD_COUNT + 1] + } + + impl PartialEq for $name<'_> { + fn eq(&self, other: &Self) -> bool { + self.__buf.eq(other.__buf) + } + } + + impl std::fmt::Debug for $name<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut s = f.debug_struct(stringify!($name)); + $( + s.field(stringify!($field), &self.$field()); + )* + s.finish() + } + } + + #[allow(unused)] + impl <'a> S<'a> { + /// Checks the constant values for this struct to determine whether + /// this message matches. + #[inline] + pub const fn is_buffer(buf: &'a [u8]) -> bool { + let mut offset = 0; + + // NOTE! This only works for fixed-sized fields and assumes + // that they all exist before variable-sized fields. + + $( + $(if $crate::protocol::FieldAccess::<$type>::extract(buf.split_at(offset).1) != $value as usize as _ { return false;})? + offset += std::mem::size_of::<$type>(); + )* + + true + } + + $( + pub const fn can_cast(parent: &<$super as $crate::protocol::Enliven>::WithLifetime<'a>) -> bool { + Self::is_buffer(parent.__buf) + } + + pub const fn try_new(parent: &<$super as $crate::protocol::Enliven>::WithLifetime<'a>) -> Option { + if Self::can_cast(parent) { + Some(Self::new(parent.__buf)) + } else { + None + } + } + )? + + /// Creates a new instance of this struct from a given buffer. 
+ #[inline] + pub const fn new(mut buf: &'a [u8]) -> Self { + let mut __field_offsets = [0; Meta::FIELD_COUNT + 1]; + let mut offset = 0; + let mut index = 0; + $( + __field_offsets[index] = offset; + offset += $crate::protocol::FieldAccess::<$type>::size_of_field_at(buf.split_at(offset).1); + index += 1; + )* + __field_offsets[index] = offset; + + Self { + __buf: buf, + __field_offsets, + } + } + + pub fn to_vec(&self) -> Vec { + self.__buf.to_vec() + } + + $( + #[doc = $fdoc] + #[allow(unused)] + #[inline] + pub const fn $field<'s>(&'s self) -> <$type as $crate::protocol::Enliven>::WithLifetime<'a> where 's : 'a { + // Perform a const buffer extraction operation + let offset1 = self.__field_offsets[F::$field as usize]; + let offset2 = self.__field_offsets[F::$field as usize + 1]; + let (_, buf) = self.__buf.split_at(offset1); + let (buf, _) = buf.split_at(offset2 - offset1); + $crate::protocol::FieldAccess::<$type>::extract(buf) + } + )* + } + ); + }; + + (__meta__, struct $name:ident { + super($($super:ident)?), + docs($($sdoc:meta),*), + fields($({ + name($field:ident), + type($type:ty), + size($($size:tt)*), + value($(value = ($value:expr))? $(no_value = $no_value:ident)? $(auto = $auto:ident)?), + docs($fdoc:expr), + fixed($fixed:ident=$fixed2:ident, $fixed_expr:expr), + $($rest:tt)* + },)*), + }) => { + paste::paste!( + $( #[$sdoc] )? 
+ #[allow(unused)] + #[derive(Debug, Default)] + pub struct [<$name Meta>] { + } + + #[allow(unused)] + #[allow(non_camel_case_types)] + #[derive(Eq, PartialEq)] + #[repr(u8)] + enum [<$name Fields>] { + $( + $field, + )* + } + + #[allow(unused)] + impl Meta { + pub const FIELD_COUNT: usize = [$(stringify!($field)),*].len(); + $($(pub const [<$field:upper _VALUE>]: $type = $crate::protocol::FieldAccess::<$type>::constant($value as usize);)?)* + } + + impl $crate::protocol::Meta for Meta { + fn name(&self) -> &'static str { + stringify!($name) + } + fn relations(&self) -> &'static [($crate::protocol::MetaRelation, &'static dyn $crate::protocol::Meta)] { + r#if!(__is_empty__ [$($super)?] { + const RELATIONS: &'static [($crate::protocol::MetaRelation, &'static dyn $crate::protocol::Meta)] = &[ + $( + ($crate::protocol::MetaRelation::Field(stringify!($field)), $crate::protocol::FieldAccess::<$type>::meta()) + ),* + ]; + } else { + const RELATIONS: &'static [($crate::protocol::MetaRelation, &'static dyn $crate::protocol::Meta)] = &[ + ($crate::protocol::MetaRelation::Parent, $crate::protocol::FieldAccess::<$($super)?>::meta()), + $( + ($crate::protocol::MetaRelation::Field(stringify!($field)), $crate::protocol::FieldAccess::<$type>::meta()) + ),* + ]; + }); + RELATIONS + } + } + + $( + protocol_builder!(__meta__, $fixed($fixed_expr) $field $type); + )* + + impl $crate::protocol::Struct for Meta { + type Struct<'a> = S<'a>; + fn new(buf: &[u8]) -> S<'_> { + S::new(buf) + } + } + + impl $crate::protocol::Enliven for Meta { + type WithLifetime<'a> = S<'a>; + type ForMeasure<'a> = M<'a>; + type ForBuilder<'a> = B<'a>; + } + + #[allow(unused)] + impl $crate::protocol::FieldAccess { + #[inline(always)] + pub const fn name() -> &'static str { + stringify!($name) + } + #[inline(always)] + pub const fn meta() -> &'static dyn $crate::protocol::Meta { + &Meta {} + } + #[inline] + pub const fn size_of_field_at(buf: &[u8]) -> usize { + let mut offset = 0; + $( + offset += 
$crate::protocol::FieldAccess::<$type>::size_of_field_at(buf.split_at(offset).1); + )* + offset + } + #[inline(always)] + pub const fn extract(buf: &[u8]) -> $name<'_> { + $name::new(buf) + } + #[inline(always)] + pub const fn measure(measure: &M) -> usize { + measure.measure() + } + #[inline(always)] + pub fn copy_to_buf(buf: &mut $crate::protocol::writer::BufWriter, builder: &B) { + builder.copy_to_buf(buf) + } + #[inline(always)] + pub fn copy_to_buf_ref(buf: &mut $crate::protocol::writer::BufWriter, builder: &B) { + builder.copy_to_buf(buf) + } + } + + $crate::protocol::field_access!{[<$name Meta>]} + $crate::protocol::arrays::array_access!{[<$name Meta>]} + ); + }; + (__meta__, fixed_offset($fixed_expr:expr) $field:ident $crate::protocol::meta::Length) => { + impl $crate::protocol::StructLength for Meta { + fn length_field_of(of: &Self::Struct<'_>) -> usize { + of.$field() + } + fn length_field_offset() -> usize { + $fixed_expr + } + } + }; + (__meta__, $fixed:ident($fixed_expr:expr) $field:ident $crate::protocol::meta::Rest) => { + + }; + (__meta__, $fixed:ident($fixed_expr:expr) $field:ident $any:ty) => { + }; + + (__measure__, struct $name:ident { + super($($super:ident)?), + docs($($sdoc:meta),*), + fields($({ + name($field:ident), + type($type:ty), + size( $( fixed=$fixed_marker:ident )? $( variable=$variable_marker:ident )? ), + value($(value = ($value:expr))? $(no_value = $no_value:ident)? $(auto = $auto:ident)?), + docs($fdoc:expr), + $($rest:tt)* + },)*), + }) => { + paste::paste!( + r#if!(__is_empty__ [$($($variable_marker)?)*] { + $( #[$sdoc] )? + // No variable-sized fields + #[derive(Default, Eq, PartialEq)] + pub struct [<$name Measure>]<'a> { + __no_fields_use_default: std::marker::PhantomData<&'a ()> + } + } else { + $( #[$sdoc] )? + pub struct [<$name Measure>]<'a> { + // Because of how macros may expand in the context of struct + // fields, we need to do a * repeat, then a ? 
repeat and + // somehow use $variable_marker in the remainder of the + // pattern. + $($( + #[doc = $fdoc] + pub $field: r#if!(__has__ [$variable_marker] {<$type as $crate::protocol::Enliven>::ForMeasure<'a>}), + )?)* + } + }); + + impl M<'_> { + pub const fn measure(&self) -> usize { + let mut size = 0; + $( + r#if!(__has__ [$($variable_marker)?] { size += $crate::protocol::FieldAccess::<$type>::measure(&self.$field); }); + r#if!(__has__ [$($fixed_marker)?] { size += std::mem::size_of::<$type>(); }); + )* + size + } + } + ); + }; + + (__builder__, struct $name:ident { + super($($super:ident)?), + docs($($sdoc:meta),*), + fields($({ + name($field:ident), + type($type:ty), + size($($size:tt)*), + value($(value = ($value:expr))? $(no_value = $no_value:ident)? $(auto = $auto:ident)?), + docs($fdoc:expr), + $($rest:tt)* + },)*), + }) => { + paste::paste!( + r#if!(__is_empty__ [$($($no_value)?)*] { + $( #[$sdoc] )? + // No unfixed-value fields + #[derive(Default, Eq, PartialEq)] + pub struct [<$name Builder>]<'a> { + __no_fields_use_default: std::marker::PhantomData<&'a ()> + } + } else { + $( #[$sdoc] )? + #[derive(Default, Eq, PartialEq)] + pub struct [<$name Builder>]<'a> { + // Because of how macros may expand in the context of struct + // fields, we need to do a * repeat, then a ? repeat and + // somehow use $no_value in the remainder of the pattern. + $($( + #[doc = $fdoc] + pub $field: r#if!(__has__ [$no_value] {<$type as $crate::protocol::Enliven>::ForBuilder<'a>}), + )?)* + } + }); + + impl B<'_> { + #[allow(unused)] + pub fn copy_to_buf(&self, buf: &mut $crate::protocol::writer::BufWriter) { + $( + r#if!(__is_empty__ [$($value)?] { + r#if!(__is_empty__ [$($auto)?] { + $crate::protocol::FieldAccess::<$type>::copy_to_buf(buf, self.$field); + } else { + let auto_offset = buf.size(); + $crate::protocol::FieldAccess::<$type>::copy_to_buf(buf, 0); + }); + } else { + $crate::protocol::FieldAccess::<$type>::copy_to_buf(buf, $($value)? 
as usize as _); + }); + )* + + $( + r#if!(__has__ [$($auto)?] { + $crate::protocol::FieldAccess::::copy_to_buf_rewind(buf, auto_offset, buf.size() - auto_offset); + }); + )* + + } + + /// Convert this builder into a vector of bytes. This is generally + /// not the most efficient way to perform serialization. + #[allow(unused)] + pub fn to_vec(&self) -> Vec { + let mut vec = Vec::with_capacity(256); + let mut buf = $crate::protocol::writer::BufWriter::new(&mut vec); + self.copy_to_buf(&mut buf); + match buf.finish() { + Ok(size) => { + vec.truncate(size); + vec + }, + Err(size) => { + vec.resize(size, 0); + let mut buf = $crate::protocol::writer::BufWriter::new(&mut vec); + self.copy_to_buf(&mut buf); + let size = buf.finish().unwrap(); + vec.truncate(size); + vec + } + } + } + } + ); + }; +} + +pub(crate) use {protocol, protocol_builder, r#if, struct_elaborate}; + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + mod fixed_only { + protocol!( + struct FixedOnly { + a: u8, + } + ); + } + + mod fixed_only_value { + protocol!(struct FixedOnlyValue { + a: u8 = 1, + }); + } + + mod mixed { + use crate::protocol::meta::ZTString; + protocol!(struct Mixed { + a: u8 = 1, + s: ZTString, + }); + } + + mod docs { + use crate::protocol::meta::ZTString; + protocol!( + /// Docs + struct Docs { + /// Docs + a: u8 = 1, + /// Docs + s: ZTString, + } + ); + } + + mod length { + use crate::protocol::meta::Length; + protocol!( + struct WithLength { + a: u8, + l: len, + } + ); + } + + mod array { + protocol!( + struct StaticArray { + a: u8, + l: [u8; 4], + } + ); + } + + macro_rules! 
assert_stringify { + (($($struct:tt)*), ($($expected:tt)*)) => { + struct_elaborate!(assert_stringify(__internal__ ($($expected)*)) => $($struct)*); + }; + (__internal__ ($($expected:tt)*), $($struct:tt)*) => { + assert_eq!(stringify!($($struct)*), stringify!($($expected)*)); + }; + } + + #[test] + fn empty_struct() { + assert_stringify!((struct Foo {}), (struct Foo { super (), docs(), fields(), })); + } + + #[test] + fn fixed_size_fields() { + assert_stringify!((struct Foo { + a: u8, + b: u8, + }), (struct Foo + { + super (), + docs(), + fields({ + name(a), type (u8), size(fixed = fixed), value(no_value = no_value), + docs(concat!("`", stringify! (a), "` field.")), + fixed(fixed_offset = fixed_offset, (0)), + }, + { + name(b), type (u8), size(fixed = fixed), value(no_value = no_value), + docs(concat!("`", stringify! (b), "` field.")), + fixed(fixed_offset = fixed_offset, ((0) + 1)), + },), + })); + } + + #[test] + fn mixed_fields() { + assert_stringify!((struct Foo { + a: u8, + l: len, + s: ZTString, + c: i16, + d: [u8; 4], + e: ZTArray, + }), (struct Foo + { + super (), + docs(), + fields({ + name(a), type (u8), size(fixed = fixed), value(no_value = no_value), + docs(concat!("`", stringify! (a), "` field.")), + fixed(fixed_offset = fixed_offset, (0)), + }, + { + name(l), type (crate::protocol::meta::Length), size(fixed = fixed), + value(auto = auto), docs(concat!("`", stringify! (l), "` field.")), + fixed(fixed_offset = fixed_offset, ((0) + 1)), + }, + { + name(s), type (ZTString), size(variable = variable), + value(no_value = no_value), + docs(concat!("`", stringify! (s), "` field.")), + fixed(fixed_offset = fixed_offset, (((0) + 1) + 4)), + }, + { + name(c), type (i16), size(fixed = fixed), value(no_value = no_value), + docs(concat!("`", stringify! (c), "` field.")), + fixed(no_fixed_offset = no_fixed_offset, (0)), + }, + { + name(d), type ([u8; 4]), size(fixed = fixed), + value(no_value = no_value), + docs(concat!("`", stringify! 
(d), "` field.")), + fixed(no_fixed_offset = no_fixed_offset, ((0) + 2)), + }, + { + name(e), type (ZTArray), size(variable = variable), + value(no_value = no_value), + docs(concat!("`", stringify! (e), "` field.")), + fixed(no_fixed_offset = no_fixed_offset, (((0) + 2) + 4)), + }, + ), + })); + } +} diff --git a/edb/server/pgrust/src/protocol/message_group.rs b/edb/server/pgrust/src/protocol/message_group.rs new file mode 100644 index 000000000000..f1b2c7d56f4d --- /dev/null +++ b/edb/server/pgrust/src/protocol/message_group.rs @@ -0,0 +1,123 @@ +macro_rules! message_group { + ($(#[$doc:meta])* $group:ident : $super:ident = [$($message:ty),*]) => { + paste::paste!( + $(#[$doc])* + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + #[allow(unused)] + pub enum $group { + $( + #[doc = concat!("Matched [`", stringify!($message), "`]")] + $message + ),* + } + + #[allow(unused)] + pub enum [<$group Builder>]<'a> { + $( + $message(builder::$message<'a>) + ),* + } + + #[allow(unused)] + impl [<$group Builder>]<'_> { + pub fn to_vec(&self) -> Vec { + match self { + $( + Self::$message(message) => message.to_vec(), + )* + } + } + } + + $( + impl <'a> From> for [<$group Builder>]<'a> { + fn from(message: builder::$message<'a>) -> Self { + Self::$message(message) + } + } + )* + + #[allow(unused)] + pub trait [<$group Match>] { + $( + fn [<$message:snake>]<'a>(&mut self) -> Option)> { + // No implementation by default + let mut opt = Some(|_| {}); + opt.take(); + opt + } + )* + // fn unknown(&mut self, message: self::struct_defs::Message::Message) { + // // No implementation by default + // } + } + + #[allow(unused)] + impl $group { + pub fn identify(buf: &[u8]) -> Option { + $( + if <$message as $crate::protocol::Enliven>::WithLifetime::is_buffer(buf) { + return Some(Self::$message); + } + )* + None + } + + pub fn match_message(matcher: &mut impl [<$group Match>], buf: &[u8]) { + $( + if data::$message::is_buffer(buf) { + if let Some(mut f) = matcher.[<$message:snake>]() { 
+                        let message = data::$message::new(buf);
+                        f(message);
+                        return;
+                    }
+                }
+                )*
+            }
+        }
+        );
+    };
+}
+pub(crate) use message_group;
+
+/// Perform a match on a message.
+///
+/// ```rust
+/// use pgrust::protocol::*;
+/// use pgrust::protocol::messages::*;
+///
+/// let buf = [b'?', 0, 0, 0, 4];
+/// match_message!(Message::new(&buf), Backend {
+///     (BackendKeyData as data) => {
+///         todo!();
+///     },
+///     unknown => {
+///         eprintln!("Unknown message: {unknown:?}");
+///     }
+/// });
+/// ```
#[doc(hidden)]
+#[macro_export]
+macro_rules! __match_message {
+    ($buf:expr, $messages:ty {
+        $(( $i1:path $(as $i2:ident )?) => $impl:block,)*
+        $unknown:ident => $unknown_impl:block $(,)?
+    }) => {
+        {
+            let __message = $buf;
+            $(
+                if let Some(__tmp) = <$i1>::try_new(&__message) {
+                    $(let $i2 = __tmp;)?
+                    $impl
+                } else
+            )*
+            {
+                let $unknown = __message;
+                $unknown_impl
+            }
+        }
+    };
+}
+
+#[doc(inline)]
+pub use __match_message as match_message;
diff --git a/edb/server/pgrust/src/protocol/mod.rs b/edb/server/pgrust/src/protocol/mod.rs
new file mode 100644
index 000000000000..62ddcb39d9be
--- /dev/null
+++ b/edb/server/pgrust/src/protocol/mod.rs
@@ -0,0 +1,392 @@
+mod arrays;
+mod buffer;
+mod datatypes;
+pub(crate) mod definition;
+mod gen;
+mod message_group;
+mod writer;
+
+/// Metatypes for the protocol and related arrays/strings.
+pub mod meta {
+    pub use super::arrays::meta::*;
+    pub use super::datatypes::meta::*;
+    pub use super::definition::meta::*;
+}
+
+/// Measurement structs.
+pub mod measure {
+    pub use super::definition::measure::*;
+}
+
+/// Builder structs.
+pub mod builder {
+    pub use super::definition::builder::*;
+}
+
+/// Message types collections.
+pub mod messages { + pub use super::definition::{Backend, Frontend}; +} + +#[allow(unused)] +pub use arrays::{Array, ArrayIter, ZTArray, ZTArrayIter}; +pub use buffer::StructBuffer; +#[allow(unused)] +pub use datatypes::{Encoded, Rest, ZTString}; +#[allow(unused)] +pub use definition::data::*; +pub use message_group::match_message; + +/// Implemented for all structs. +pub trait Struct { + type Struct<'a>; + fn new(buf: &[u8]) -> Self::Struct<'_>; +} + +/// Implemented for all generated structs that have a [`meta::Length`] field at a fixed offset. +pub trait StructLength: Struct { + fn length_field_of(of: &Self::Struct<'_>) -> usize; + fn length_field_offset() -> usize; + fn length_of_buf(buf: &[u8]) -> Option { + if buf.len() < Self::length_field_offset() + std::mem::size_of::() { + None + } else { + Some( + Self::length_field_offset() + + FieldAccess::::extract( + &buf[Self::length_field_offset() + ..Self::length_field_offset() + std::mem::size_of::()], + ), + ) + } + } +} + +/// For a given metaclass, returns the inflated type, a measurement type and a +/// builder type. 
+pub trait Enliven { + type WithLifetime<'a>; + type ForMeasure<'a>: 'a; + type ForBuilder<'a>: 'a; +} + +pub trait FixedSize { + const SIZE: usize; +} + +#[derive(Debug, Eq, PartialEq)] +pub enum MetaRelation { + Parent, + Length, + Item, + Field(&'static str), +} + +pub trait Meta { + fn name(&self) -> &'static str { + std::any::type_name::() + } + fn relations(&self) -> &'static [(MetaRelation, &'static dyn Meta)] { + &[] + } + fn field(&self, name: &'static str) -> Option<&'static dyn Meta> { + for (relation, meta) in self.relations() { + if relation == &MetaRelation::Field(name) { + return Some(*meta); + } + } + None + } + fn parent(&self) -> Option<&'static dyn Meta> { + for (relation, meta) in self.relations() { + if relation == &MetaRelation::Parent { + return Some(*meta); + } + } + None + } +} + +impl PartialEq for dyn Meta { + fn eq(&self, other: &T) -> bool { + other.name() == self.name() + } +} + +impl std::fmt::Debug for dyn Meta { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut s = f.debug_struct(self.name()); + for (relation, meta) in self.relations() { + if relation == &MetaRelation::Parent { + s.field(&format!("{relation:?}"), &meta.name()); + } else { + s.field(&format!("{relation:?}"), meta); + } + } + s.finish() + } +} + +/// Delegates to a concrete [`FieldAccess`] but as a non-const trait. This is +/// used for performing extraction in iterators. +pub(crate) trait FieldAccessArray: Enliven { + const META: &'static dyn Meta; + fn size_of_field_at(buf: &[u8]) -> usize; + fn extract(buf: &[u8]) -> ::WithLifetime<'_>; +} + +/// This struct is specialized for each type we want to extract data from. We +/// have to do it this way to work around Rust's lack of const specialization. +pub(crate) struct FieldAccess { + _phantom_data: std::marker::PhantomData, +} + +/// Delegate to the concrete [`FieldAccess`] for each type we want to extract. +macro_rules! 
field_access { + ($ty:ty) => { + impl $crate::protocol::FieldAccessArray for $ty { + const META: &'static dyn $crate::protocol::Meta = + $crate::protocol::FieldAccess::<$ty>::meta(); + #[inline(always)] + fn size_of_field_at(buf: &[u8]) -> usize { + $crate::protocol::FieldAccess::<$ty>::size_of_field_at(buf) + } + #[inline(always)] + fn extract(buf: &[u8]) -> ::WithLifetime<'_> { + $crate::protocol::FieldAccess::<$ty>::extract(buf) + } + } + }; +} +pub(crate) use field_access; + +#[cfg(test)] +mod tests { + use super::*; + use buffer::StructBuffer; + use definition::builder; + + #[test] + fn test_sasl_response() { + let buf = [b'p', 0, 0, 0, 5, 2]; + assert!(SASLResponse::is_buffer(&buf)); + let message = SASLResponse::new(&buf); + assert_eq!(message.mlen(), 5); + assert_eq!(message.response().len(), 1); + } + + #[test] + fn test_sasl_response_measure() { + let measure = measure::SASLResponse { + response: &[1, 2, 3, 4, 5], + }; + assert_eq!(measure.measure(), 10) + } + + #[test] + fn test_sasl_initial_response() { + let buf = [ + b'p', 0, 0, 0, 0x36, // Mechanism + b'S', b'C', b'R', b'A', b'M', b'-', b'S', b'H', b'A', b'-', b'2', b'5', b'6', 0, + // Data + 0, 0, 0, 32, b'n', b',', b',', b'n', b'=', b',', b'r', b'=', b'p', b'E', b'k', b'P', + b'L', b'Q', b'u', b'2', b'9', b'G', b'E', b'v', b'w', b'N', b'e', b'V', b'J', b't', + b'7', b'2', b'a', b'r', b'Q', b'I', + ]; + + assert!(SASLInitialResponse::is_buffer(&buf)); + let message = SASLInitialResponse::new(&buf); + assert_eq!(message.mlen(), 0x36); + assert_eq!(message.mechanism(), "SCRAM-SHA-256"); + assert_eq!( + message.response().as_ref(), + b"n,,n=,r=pEkPLQu29GEvwNeVJt72arQI" + ); + } + + #[test] + fn test_sasl_initial_response_builder() { + let buf = builder::SASLInitialResponse { + mechanism: "SCRAM-SHA-256", + response: b"n,,n=,r=pEkPLQu29GEvwNeVJt72arQI", + } + .to_vec(); + + let message = SASLInitialResponse::new(&buf); + assert_eq!(message.mlen(), 0x36); + assert_eq!(message.mechanism(), 
"SCRAM-SHA-256"); + assert_eq!( + message.response().as_ref(), + b"n,,n=,r=pEkPLQu29GEvwNeVJt72arQI" + ); + } + + #[test] + fn test_startup_message() { + let buf = [ + 0, 0, 0, 41, 0, 0x03, 0, 0, 0x75, 0x73, 0x65, 0x72, 0, 0x70, 0x6f, 0x73, 0x74, 0x67, + 0x72, 0x65, 0x73, 0, 0x64, 0x61, 0x74, 0x61, 0x62, 0x61, 0x73, 0x65, 0, 0x70, 0x6f, + 0x73, 0x74, 0x67, 0x72, 0x65, 0x73, 0, 0, + ]; + let message = StartupMessage::new(&buf); + assert_eq!(message.mlen(), buf.len()); + assert_eq!(message.protocol(), 196608); + let arr = message.params(); + let mut vals = vec![]; + for entry in arr { + vals.push(entry.name().to_owned().unwrap()); + vals.push(entry.value().to_owned().unwrap()); + } + assert_eq!(vals, vec!["user", "postgres", "database", "postgres"]); + } + + #[test] + fn test_row_description() { + let buf = [ + b'T', 0, 0, 0, 48, // header + 0, 2, // # of fields + b'f', b'1', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // field 1 + b'f', b'2', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // field 2 + ]; + assert!(RowDescription::is_buffer(&buf)); + let message = RowDescription::new(&buf); + assert_eq!(message.mlen(), buf.len() - 1); + assert_eq!(message.fields().len(), 2); + let mut iter = message.fields().into_iter(); + let f1 = iter.next().unwrap(); + assert_eq!(f1.name(), "f1"); + let f2 = iter.next().unwrap(); + assert_eq!(f2.name(), "f2"); + assert_eq!(None, iter.next()); + } + + #[test] + fn test_row_description_measure() { + let measure = measure::RowDescription { + fields: &[ + measure::RowField { name: "F1" }, + measure::RowField { name: "F2" }, + ], + }; + assert_eq!(49, measure.measure()) + } + + #[test] + fn test_row_description_builder() { + let builder = builder::RowDescription { + fields: &[ + builder::RowField { + name: "F1", + column_attr_number: 1, + ..Default::default() + }, + builder::RowField { + name: "F2", + data_type_oid: 1234, + format_code: 1, + ..Default::default() + }, + ], + }; + + let vec = builder.to_vec(); + 
assert_eq!(49, vec.len()); + + // Read it back + assert!(RowDescription::is_buffer(&vec)); + let message = RowDescription::new(&vec); + assert_eq!(message.fields().len(), 2); + let mut iter = message.fields().into_iter(); + let f1 = iter.next().unwrap(); + assert_eq!(f1.name(), "F1"); + assert_eq!(f1.column_attr_number(), 1); + let f2 = iter.next().unwrap(); + assert_eq!(f2.name(), "F2"); + assert_eq!(f2.data_type_oid(), 1234); + assert_eq!(f2.format_code(), 1); + assert_eq!(None, iter.next()); + } + + #[test] + fn test_message_polymorphism_sync() { + let sync = builder::Sync::default(); + let buf = sync.to_vec(); + assert_eq!(buf.len(), 5); + // Read it as a Message + let message = Message::new(&buf); + assert_eq!(message.mlen(), 4); + assert_eq!(message.mtype(), b'S'); + assert_eq!(message.data(), &[]); + // And also a Sync + assert!(Sync::is_buffer(&buf)); + let message = Sync::new(&buf); + assert_eq!(message.mlen(), 4); + assert_eq!(message.mtype(), b'S'); + } + + #[test] + fn test_message_polymorphism_rest() { + let auth = builder::AuthenticationGSSContinue { + data: &[1, 2, 3, 4, 5], + }; + let buf = auth.to_vec(); + assert_eq!(14, buf.len()); + // Read it as a Message + assert!(Message::is_buffer(&buf)); + let message = Message::new(&buf); + assert_eq!(message.mlen(), 13); + assert_eq!(message.mtype(), b'R'); + assert_eq!(message.data(), &[0, 0, 0, 8, 1, 2, 3, 4, 5]); + // And also a AuthenticationGSSContinue + assert!(AuthenticationGSSContinue::is_buffer(&buf)); + let message = AuthenticationGSSContinue::new(&buf); + assert_eq!(message.mlen(), 13); + assert_eq!(message.mtype(), b'R'); + assert_eq!(message.data(), &[1, 2, 3, 4, 5]); + } + + #[test] + fn test_query_messages() { + let data: Vec = vec![ + 0x54, 0x00, 0x00, 0x00, 0x21, 0x00, 0x01, 0x3f, b'c', b'o', b'l', b'u', b'm', b'n', + 0x3f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x17, 0x00, 0x04, + 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x01, 0x00, + 
0x00, 0x00, 0x01, b'1', b'C', 0x00, 0x00, 0x00, 0x0d, b'S', b'E', b'L', b'E', b'C', + b'T', b' ', b'1', 0x00, 0x5a, 0x00, 0x00, 0x00, 0x05, b'I', + ]; + + let mut buffer = StructBuffer::::default(); + buffer.push(&data, |message| { + match_message!(message, Backend { + (RowDescription as row) => { + assert_eq!(row.fields().len(), 1); + let field = row.fields().into_iter().next().unwrap(); + assert_eq!(field.name(), "?column?"); + assert_eq!(field.data_type_oid(), 23); + assert_eq!(field.format_code(), 0); + }, + (DataRow as row) => { + assert_eq!(row.values().len(), 1); + assert_eq!(row.values().into_iter().next().unwrap(), "1"); + }, + (CommandComplete as complete) => { + assert_eq!(complete.tag(), "SELECT 1"); + }, + (ReadyForQuery as ready) => { + assert_eq!(ready.status(), b'I'); + }, + unknown => { + panic!("Unknown message type: {:?}", unknown); + } + }); + }); + } + + #[test] + fn test_encode_data_row() { + builder::DataRow { + values: &[Encoded::Value(b"1")], + } + .to_vec(); + } +} diff --git a/edb/server/pgrust/src/protocol/writer.rs b/edb/server/pgrust/src/protocol/writer.rs new file mode 100644 index 000000000000..eafb46fbd3af --- /dev/null +++ b/edb/server/pgrust/src/protocol/writer.rs @@ -0,0 +1,97 @@ +#[derive(Debug)] +pub struct BufWriter<'a> { + buf: &'a mut [u8], + size: usize, + error: bool, +} + +impl<'a> BufWriter<'a> { + #[inline(always)] + pub fn new(buf: &'a mut [u8]) -> Self { + Self { + buf, + size: 0, + error: false, + } + } + + #[inline] + pub fn test(&mut self, size: usize) -> bool { + if self.buf.len() < size { + self.size += size; + self.error = true; + false + } else { + true + } + } + + #[inline] + pub fn size(&self) -> usize { + self.size + } + + #[inline] + pub fn write_rewind(&mut self, offset: usize, buf: &[u8]) { + if self.error { + return; + } + self.buf[offset..offset + buf.len()].copy_from_slice(buf); + } + + #[inline] + pub fn write(&mut self, buf: &[u8]) { + let len = buf.len(); + self.size += len; + if self.error { + 
return; + } + if self.buf.len() < len { + self.error = true; + return; + } + self.buf[self.size - len..self.size].copy_from_slice(buf); + } + + #[inline] + pub fn write_u8(&mut self, value: u8) { + self.size += 1; + if self.error { + return; + } + if self.buf.is_empty() { + self.error = true; + return; + } + self.buf[self.size - 1] = value; + } + + pub const fn finish(self) -> Result { + if self.error { + Err(self.size) + } else { + Ok(self.size) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_buf_writer() { + let mut buf = [0u8; 10]; + let mut writer = BufWriter::new(&mut buf); + writer.write(b"hello"); + assert_eq!(writer.size(), 5); + } + + #[test] + fn test_buf_writer_too_large() { + let mut buf = [0u8; 10]; + let mut writer = BufWriter::new(&mut buf); + writer.write(b"hello world"); + assert_eq!(writer.size(), 11); + assert!(writer.error); + } +} diff --git a/edb/server/pgrust/src/python.rs b/edb/server/pgrust/src/python.rs index 9fda2e9071b3..a798f43c8aa1 100644 --- a/edb/server/pgrust/src/python.rs +++ b/edb/server/pgrust/src/python.rs @@ -1,13 +1,53 @@ -use std::path::Path; - -use crate::conn_string::{self, EnvVar}; +use crate::connection::params::ConnectionParameters; +use crate::connection::state_machine::{ + ConnectionDrive, ConnectionState, ConnectionStateSend, ConnectionStateUpdate, +}; +use crate::connection::{ + ConnectionError, ConnectionSslRequirement, Credentials, RawConnectionParameters, SslMode, +}; +use crate::protocol::{meta, ErrorResponse, StructBuffer}; +use crate::{ + connection::{ + dsn::{parse_postgres_dsn, EnvVar}, + params::Ssl, + state_machine::ConnectionStateType, + HostType, + }, + protocol::SSLResponse, +}; +use pyo3::exceptions::PyRuntimeError; +use pyo3::types::{PyBytes, PyDict}; +use pyo3::{ + buffer::PyBuffer, + prelude::*, + types::{PyMemoryView, PyNone}, +}; use pyo3::{ exceptions::PyException, - pyfunction, pymodule, + pymodule, types::{PyAnyMethods, PyByteArray, PyModule, PyModuleMethods}, - 
wrap_pyfunction, Bound, PyAny, PyResult, Python, + Bound, PyAny, PyResult, Python, }; use serde_pickle::SerOptions; +use std::collections::HashMap; +use std::path::Path; + +#[derive(Clone, Copy, PartialEq, Eq)] +#[pyclass(eq, eq_int)] +pub enum SSLMode { + Disable, + Allow, + Prefer, + Require, + VerifyCa, + VerifyFull, +} + +impl From for PyErr { + fn from(err: ConnectionError) -> PyErr { + PyRuntimeError::new_err(err.to_string()) + } +} impl EnvVar for (String, Bound<'_, PyAny>) { fn read(&self, name: &'static str) -> Option> { @@ -21,32 +61,365 @@ impl EnvVar for (String, Bound<'_, PyAny>) { } } -#[pyfunction] -fn parse_dsn(py: Python, username: String, home_dir: String, s: String) -> PyResult> { - let pickle = py.import_bound("pickle")?; - let loads = pickle.getattr("loads")?; - let os = py.import_bound("os")?; - let environ = os.getattr("environ")?; - match conn_string::parse_postgres_url(&s, (username, environ)) { - Ok(mut res) => { - if let Some(warning) = - res.password - .resolve(Path::new(&home_dir), &res.hosts, &res.database, &res.user)? 
- { - let warnings = py.import_bound("warnings")?; - warnings.call_method1("warn", (warning.to_string(),))?; +#[pyclass] +struct PyConnectionParams { + inner: RawConnectionParameters<'static>, +} + +#[pymethods] +impl PyConnectionParams { + #[new] + #[pyo3(signature = (dsn=None))] + fn new(py: Python, dsn: Option) -> PyResult { + if let Some(dsn_str) = dsn { + match parse_postgres_dsn(&dsn_str) { + Ok(params) => Ok(PyConnectionParams { + inner: params.to_static(), + }), + Err(err) => Err(PyException::new_err(err.to_string())), } - let paths = res.ssl.resolve(Path::new(&home_dir))?; - // Use serde_pickle to get a python-compatible representation of the result - let vec = serde_pickle::to_vec(&(res, paths), SerOptions::new()).unwrap(); - loads.call1((PyByteArray::new_bound(py, &vec),)) + } else { + Ok(PyConnectionParams { + inner: RawConnectionParameters::default(), + }) } - Err(err) => Err(PyException::new_err(err.to_string())), + } + + #[getter] + pub fn hosts(&self) -> Vec<(String, String, u16)> { + self.inner + .hosts() + .unwrap_or_default() + .iter() + .map(|host| match &host.0 { + HostType::Hostname(name) => ("tcp".to_string(), name.clone(), host.1), + HostType::IP(ip, _) => ("tcp".to_string(), ip.to_string(), host.1), + HostType::Path(path) => ("unix".to_string(), path.clone(), host.1), + HostType::Abstract(path) => ("unix".to_string(), path.clone(), host.1), + }) + .collect() + } + + #[getter] + pub fn keys(&self) -> Vec<&str> { + RawConnectionParameters::field_names() + } + + pub fn to_dict(&self) -> HashMap { + self.inner.clone().into() + } + + pub fn update_server_settings(&mut self, py: Python, key: &str, value: &str) -> PyResult<()> { + self.inner + .server_settings + .get_or_insert_with(HashMap::new) + .insert(key.to_string().into(), value.to_string().into()); + Ok(()) + } + + pub fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } + + pub fn resolve(&self, py: Python, username: String, home_dir: String) -> PyResult { + let os = 
py.import_bound("os")?; + let environ = os.getattr("environ")?; + + let mut params = self.inner.clone(); + params + .apply_env((username.clone(), environ)) + .map_err(|err| PyException::new_err(err.to_string()))?; + let mut params = ConnectionParameters::try_from(params) + .map_err(|err| PyException::new_err(err.to_string()))?; + if let Some(warning) = params.password.resolve( + Path::new(&home_dir), + ¶ms.hosts, + ¶ms.database, + ¶ms.user, + )? { + let warnings = py.import_bound("warnings")?; + warnings.call_method1("warn", (warning.to_string(),))?; + } + + params + .ssl + .resolve(Path::new(&home_dir)) + .map_err(|err| PyException::new_err(err.to_string()))?; + Ok(Self { + inner: params.into(), + }) + } + + pub fn to_dsn(&self) -> String { + self.inner.to_url() + } + + fn __repr__(&self) -> String { + let field_names = RawConnectionParameters::field_names(); + let mut repr = "'); + repr + } + + pub fn __getitem__(&self, py: Python, name: &str) -> Py { + self.inner.get_by_name(name).to_object(py) + } + + pub fn __setitem__(&mut self, py: Python, name: &str, value: &str) -> PyResult<()> { + self.inner + .set_by_name(name, value.to_string().into()) + .map_err(|e| PyException::new_err(e.to_string()))?; + Ok(()) } } #[pymodule] -fn _pg_rust(_py: Python, m: &Bound) -> PyResult<()> { - m.add_function(wrap_pyfunction!(parse_dsn, m)?)?; +pub fn _pg_rust(_py: Python, m: &Bound) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; Ok(()) } + +#[pyclass] +struct PyConnectionState { + inner: ConnectionState, + parsed_dsn: Py, + update: PyConnectionStateUpdate, + message_buffer: StructBuffer, +} + +#[pymethods] +impl PyConnectionState { + #[new] + fn new( + py: Python, + dsn: &PyConnectionParams, + username: String, + home_dir: String, + ) -> PyResult { + let os = py.import_bound("os")?; + let environ = os.getattr("environ")?; + + let mut params = dsn.inner.clone(); + params + .apply_env((username.clone(), environ)) + .map_err(|err| 
PyException::new_err(err.to_string()))?; + let mut params = ConnectionParameters::try_from(params) + .map_err(|err| PyException::new_err(err.to_string()))?; + if let Some(warning) = params.password.resolve( + Path::new(&home_dir), + ¶ms.hosts, + ¶ms.database, + ¶ms.user, + )? { + let warnings = py.import_bound("warnings")?; + warnings.call_method1("warn", (warning.to_string(),))?; + } + + params + .ssl + .resolve(Path::new(&home_dir)) + .map_err(|err| PyException::new_err(err.to_string()))?; + let credentials = Credentials { + username: params.user.clone(), + password: params.password.password().unwrap_or_default().to_string(), + database: params.database.clone(), + server_settings: params.server_settings.clone(), + }; + let ssl_mode = match params.ssl { + Ssl::Disable => ConnectionSslRequirement::Disable, + Ssl::Enable(SslMode::Allow | SslMode::Prefer, ..) => ConnectionSslRequirement::Optional, + _ => ConnectionSslRequirement::Required, + }; + let params = params.into(); + Ok(PyConnectionState { + inner: ConnectionState::new(credentials, ssl_mode), + parsed_dsn: Py::new(py, PyConnectionParams { inner: params })?, + update: PyConnectionStateUpdate { + py_update: PyNone::get_bound(py).to_object(py), + }, + message_buffer: Default::default(), + }) + } + + #[setter] + fn update(&mut self, py: Python, update: &Bound) { + self.update.py_update = update.to_object(py); + } + + fn is_ready(&self) -> bool { + self.inner.is_ready() + } + + fn read_ssl_response(&self) -> bool { + self.inner.read_ssl_response() + } + + fn drive_initial(&mut self) -> PyResult<()> { + self.inner + .drive(ConnectionDrive::Initial, &mut self.update)?; + Ok(()) + } + + fn drive_message(&mut self, py: Python, data: &Bound) -> PyResult<()> { + let buffer = PyBuffer::::get_bound(data)?; + if self.inner.read_ssl_response() { + // SSL responses are always one character + let response = [buffer.as_slice(py).unwrap().get(0).unwrap().get()]; + let response = SSLResponse::new(&response); + self.inner + 
.drive(ConnectionDrive::SslResponse(response), &mut self.update)?; + } else { + with_python_buffer(py, buffer, |buf| { + self.message_buffer.push_fallible(buf, |message| { + self.inner + .drive(ConnectionDrive::Message(message), &mut self.update) + }) + })?; + } + Ok(()) + } + + fn drive_ssl_ready(&mut self) -> PyResult<()> { + self.inner + .drive(ConnectionDrive::SslReady, &mut self.update)?; + Ok(()) + } + + #[getter] + fn config(&self, py: Python) -> PyResult> { + Ok(self.parsed_dsn.clone_ref(py)) + } +} + +/// Attempt to stack-copy the data from a `PyBuffer`. +#[inline(always)] +fn with_python_buffer(py: Python, data: PyBuffer, mut f: impl FnMut(&[u8]) -> T) -> T { + let len = data.item_count(); + if len <= 128 { + let mut slice = [0; 128]; + data.copy_to_slice(py, &mut slice[..len]).unwrap(); + f(&slice[..len]) + } else if len <= 1024 { + let mut slice = [0; 1024]; + data.copy_to_slice(py, &mut slice[..len]).unwrap(); + f(&slice[..len]) + } else { + f(&data.to_vec(py).unwrap()) + } +} + +struct PyConnectionStateUpdate { + py_update: Py, +} + +impl ConnectionStateSend for PyConnectionStateUpdate { + fn send_initial( + &mut self, + message: crate::protocol::definition::InitialBuilder, + ) -> Result<(), std::io::Error> { + Python::with_gil(|py| { + let bytes = PyByteArray::new_bound(py, &message.to_vec()); + if let Err(e) = self.py_update.call_method1(py, "send", (bytes,)) { + eprintln!("Error in send_initial: {:?}", e); + e.print(py); + } + }); + Ok(()) + } + + fn send( + &mut self, + message: crate::protocol::definition::FrontendBuilder, + ) -> Result<(), std::io::Error> { + Python::with_gil(|py| { + let bytes = PyBytes::new_bound(py, &message.to_vec()); + if let Err(e) = self.py_update.call_method1(py, "send", (bytes,)) { + eprintln!("Error in send: {:?}", e); + e.print(py); + } + }); + Ok(()) + } + + fn upgrade(&mut self) -> Result<(), std::io::Error> { + Python::with_gil(|py| { + if let Err(e) = self.py_update.call_method0(py, "upgrade") { + eprintln!("Error 
in upgrade: {:?}", e); + e.print(py); + } + }); + Ok(()) + } +} + +impl ConnectionStateUpdate for PyConnectionStateUpdate { + fn parameter(&mut self, name: &str, value: &str) { + Python::with_gil(|py| { + if let Err(e) = self.py_update.call_method1(py, "parameter", (name, value)) { + eprintln!("Error in parameter: {:?}", e); + e.print(py); + } + }); + } + + fn cancellation_key(&mut self, pid: i32, key: i32) { + Python::with_gil(|py| { + if let Err(e) = self + .py_update + .call_method1(py, "cancellation_key", (pid, key)) + { + eprintln!("Error in cancellation_key: {:?}", e); + e.print(py); + } + }); + } + + fn state_changed(&mut self, state: ConnectionStateType) { + Python::with_gil(|py| { + if let Err(e) = self + .py_update + .call_method1(py, "state_changed", (state as u8,)) + { + eprintln!("Error in state_changed: {:?}", e); + e.print(py); + } + }); + } + + fn auth(&mut self, auth: crate::connection::Authentication) { + Python::with_gil(|py| { + if let Err(e) = self.py_update.call_method1(py, "auth", (auth as u8,)) { + eprintln!("Error in auth: {:?}", e); + e.print(py); + } + }); + } + + fn server_error(&mut self, error: &ErrorResponse) { + Python::with_gil(|py| { + let mut fields = vec![]; + for field in error.fields() { + let etype = field.etype() as char; + let message = field.value().to_string_lossy().to_string(); + fields.push((etype, message)); + } + if let Err(e) = self.py_update.call_method1(py, "server_error", (fields,)) { + eprintln!("Error in server_error: {:?}", e); + e.print(py); + } + }); + } +} diff --git a/edb/server/pgrust/tests/edgedb_test_cases.rs b/edb/server/pgrust/tests/edgedb_test_cases.rs new file mode 100644 index 000000000000..6f2b0c6a0723 --- /dev/null +++ b/edb/server/pgrust/tests/edgedb_test_cases.rs @@ -0,0 +1,285 @@ +mod test_util; + +test_case!(all_env_default_ssl, "postgresql://host:123/testdb", env={ + "PGUSER": "user", + "PGDATABASE": "testdb", + "PGPASSWORD": "passw", + "PGHOST": "host", + "PGPORT": "123", + 
"PGCONNECT_TIMEOUT": "8" +}, output={ + "user": "user", + "password": "passw", + "dbname": "testdb", + "host": "host", + "port": "123", + "sslmode": "prefer", + "connect_timeout": "8" +}); + +test_case!(dsn_override_env, "postgres://user2:passw2@host2:456/db2?connect_timeout=6", env={ + "PGUSER": "user", + "PGDATABASE": "testdb", + "PGPASSWORD": "passw", + "PGHOST": "host", + "PGPORT": "123", + "PGCONNECT_TIMEOUT": "8" +}, output={ + "user": "user2", + "password": "passw2", + "dbname": "db2", + "host": "host2", + "port": "456", + "connect_timeout": "6" +}); + +test_case!(dsn_override_env_ssl, "postgres://user2:passw2@host2:456/db2?sslmode=disable", env={ + "PGUSER": "user", + "PGDATABASE": "testdb", + "PGPASSWORD": "passw", + "PGHOST": "host", + "PGPORT": "123", + "PGSSLMODE": "allow" +}, output={ + "user": "user2", + "password": "passw2", + "dbname": "db2", + "host": "host2", + "port": "456", + "sslmode": "disable", +}); + +test_case!(dsn_overrides_env_partially, "postgres://user3:123123@localhost:5555/abcdef", env={ + "PGUSER": "user", + "PGDATABASE": "testdb", + "PGPASSWORD": "passw", + "PGHOST": "host", + "PGPORT": "123", + "PGSSLMODE": "allow" +}, output={ + "user": "user3", + "password": "123123", + "dbname": "abcdef", + "host": "localhost", + "port": "5555", + "sslmode": "allow" +}); + +test_case!(dsn_override_env_ssl_prefer, "postgres://user2:passw2@host2:456/db2?sslmode=disable", env={ + "PGUSER": "user", + "PGDATABASE": "testdb", + "PGPASSWORD": "passw", + "PGHOST": "host", + "PGPORT": "123", + "PGSSLMODE": "prefer" +}, output={ + "user": "user2", + "password": "passw2", + "dbname": "db2", + "host": "host2", + "port": "456", + "sslmode": "disable", +}); + +test_case!(dsn_overrides_env_partially_ssl_prefer, "postgres://user3:123123@localhost:5555/abcdef", env={ + "PGUSER": "user", + "PGDATABASE": "testdb", + "PGPASSWORD": "passw", + "PGHOST": "host", + "PGPORT": "123", + "PGSSLMODE": "prefer" +}, output={ + "user": "user3", + "password": "123123", + 
"dbname": "abcdef", + "host": "localhost", + "port": "5555", + "sslmode": "prefer" +}); + +test_case!(dsn_only, "postgres://user3:123123@localhost:5555/abcdef", output={ + "user": "user3", + "password": "123123", + "dbname": "abcdef", + "host": "localhost", + "port": "5555" +}); + +test_case!(dsn_only_multi_host, "postgresql://user@host1,host2/db", output={ + "user": "user", + "dbname": "db", + "host": "host1,host2", + "port": "5432,5432" +}); + +test_case!(dsn_only_multi_host_and_port, "postgresql://user@host1:1111,host2:2222/db", output={ + "user": "user", + "dbname": "db", + "host": "host1,host2", + "port": "1111,2222" +}); + +test_case!(params_multi_host_dsn_env_mix, "postgresql://host1,host2/db", env={ + "PGUSER": "foo" +}, output={ + "user": "foo", + "dbname": "db", + "host": "host1,host2", + "port": "5432,5432" +}); + +test_case!(dsn_settings_override_and_ssl, "postgresql://me:ask@127.0.0.1:888/db?param=sss¶m=123&host=testhost&user=testuser&port=2222&dbname=testdb&sslmode=require", output={ + "user": "testuser", + "password": "ask", + "dbname": "testdb", + "host": "testhost", + "port": "2222", + "sslmode": "require", + "param": "123" +}, expect_libpq_mismatch="Extra params are unsupported"); + +test_case!(multiple_settings, "postgresql://me:ask@127.0.0.1:888/db?param=sss¶m=123&host=testhost&user=testuser&port=2222&dbname=testdb&sslmode=verify_full&aa=bb", output={ + "user": "testuser", + "password": "ask", + "dbname": "testdb", + "host": "testhost", + "port": "2222", + "sslmode": "verify-full", + "aa": "bb", + "param": "123" +}, expect_libpq_mismatch="Extra params are unsupported"); + +test_case!(dsn_only_unix, "postgresql:///dbname?host=/unix_sock/test&user=spam", output={ + "user": "spam", + "dbname": "dbname", + "host": "/unix_sock/test", + "port": "5432" +}); + +test_case!(dsn_only_quoted, "postgresql://us%40r:p%40ss@h%40st1,h%40st2:543%33/d%62", output={ + "user": "us@r", + "password": "p@ss", + "dbname": "db", + "host": "h@st1,h@st2", + "port": 
"5432,5433" +}); + +test_case!(dsn_only_unquoted_host, "postgresql://user:p@ss@host/db", output={ + "user": "user", + "password": "p", + "dbname": "db", + "host": "ss@host", + "port": "5432" +}); + +test_case!(dsn_only_quoted_params, "postgresql:///d%62?user=us%40r&host=h%40st&port=543%33", output={ + "user": "us@r", + "dbname": "db", + "host": "h@st", + "port": "5433" +}); + +test_case!(dsn_ipv6_multi_host, "postgresql://user@[2001:db8::1234%25eth0],[::1]/db", output={ + "user": "user", + "dbname": "db", + "host": "2001:db8::1234%eth0,::1", + "port": "5432,5432" +}); + +test_case!(dsn_ipv6_multi_host_port, "postgresql://user@[2001:db8::1234]:1111,[::1]:2222/db", output={ + "user": "user", + "dbname": "db", + "host": "2001:db8::1234,::1", + "port": "1111,2222" +}); + +test_case!(dsn_ipv6_multi_host_query_part, "postgresql:///db?user=user&host=2001:db8::1234,::1", output={ + "user": "user", + "dbname": "db", + "host": "2001:db8::1234,::1", + "port": "5432,5432" +}); + +test_case!( + dsn_only_illegal_protocol, + "pq:///dbname?host=/unix_sock/test&user=spam", + error = "Invalid DSN.*" +); + +test_case!(env_ports_mismatch_dsn_multi_hosts, "postgresql://host1,host2,host3/db", env={ "PGPORT": "111,222" }, error="Unexpected number of ports.*", expect_libpq_mismatch="Port count check doesn't happen in parse"); + +test_case!(dsn_only_quoted_unix_host_port_in_params, "postgres://user@?port=56226&host=%2Ftmp", output={ + "user": "user", + "dbname": "user", + "host": "/tmp", + "port": "56226", + "sslmode": "disable", +}); + +test_case!(dsn_only_cloudsql, "postgres:///db?host=/cloudsql/project:region:instance-name&user=spam", output={ + "user": "spam", + "dbname": "db", + "host": "/cloudsql/project:region:instance-name", + "port": "5432" +}); + +test_case!(connect_timeout_neg8, "postgres://spam@127.0.0.1:5432/postgres?connect_timeout=-8", output={ + "user": "spam", + "dbname": "postgres", + "host": "127.0.0.1", + "port": "5432" +}); + +test_case!(connect_timeout_neg1, 
"postgres://spam@127.0.0.1:5432/postgres?connect_timeout=-1", output={ + "user": "spam", + "dbname": "postgres", + "host": "127.0.0.1", + "port": "5432" +}); + +test_case!(connect_timeout_0, "postgres://spam@127.0.0.1:5432/postgres?connect_timeout=0", output={ + "user": "spam", + "dbname": "postgres", + "host": "127.0.0.1", + "port": "5432" +}); + +test_case!(connect_timeout_1, "postgres://spam@127.0.0.1:5432/postgres?connect_timeout=1", output={ + "user": "spam", + "dbname": "postgres", + "host": "127.0.0.1", + "port": "5432", + "connect_timeout": "2" +}); + +test_case!(connect_timeout_2, "postgres://spam@127.0.0.1:5432/postgres?connect_timeout=2", output={ + "user": "spam", + "dbname": "postgres", + "host": "127.0.0.1", + "port": "5432", + "connect_timeout": "2" +}); + +test_case!(connect_timeout_3, "postgres://spam@127.0.0.1:5432/postgres?connect_timeout=3", output={ + "user": "spam", + "dbname": "postgres", + "host": "127.0.0.1", + "port": "5432", + "connect_timeout": "3" +}); + +// We intentially don't pass these tests + +test_case!(dsn_combines_env_multi_host, "postgresql:///db", env={ + "PGHOST": "host1:1111,host2:2222", + "PGUSER": "foo" +}, error="", expect_libpq_mismatch="libpq parses hostnames with colons"); + +test_case!(dsn_only_cloudsql_unix_and_tcp, "postgres:///db?host=127.0.0.1:5432,/cloudsql/project:region:instance-name,localhost:5433&user=spam", error="", expect_libpq_mismatch="libpq parses hostnames with colons"); + +test_case!( + dsn_multi_host_combines_env, + "postgresql:///db?host=host1:1111,host2:2222", + error = "", + expect_libpq_mismatch = "libpq parses hostnames with colons" +); diff --git a/edb/server/pgrust/tests/hardcore_host_tests_cases.rs b/edb/server/pgrust/tests/hardcore_host_tests_cases.rs new file mode 100644 index 000000000000..7c8a186a2a8a --- /dev/null +++ b/edb/server/pgrust/tests/hardcore_host_tests_cases.rs @@ -0,0 +1,168 @@ +mod test_util; + +test_case!(host_1, "postgres://host:1", output = {"host": "host", "port": "1"}, 
no_env = no_env); + +test_case!( + host_2, + "postgres://host:1?host=host2", + output = {"host": "host2", "port": "1"}, + no_env = no_env +); + +test_case!( + host_3, + "postgres://host:1,host2:", + output = {"port": "1,", "host": "host,host2"}, + no_env = no_env +); + +test_case!( + host_4, + "postgres://host:1,host2:,host3,host4:4", + output = {"host": "host,host2,host3,host4", "port": "1,,,4"}, + no_env = no_env +); + +test_case!( + host_5, + "postgres://host:1?port=2,3", + output = {"port": "2,3", "host": "host"}, + no_env = no_env +); + +test_case!(host_6, "postgres://host,host2:2", output = {"host": "host,host2", "port": ",2"}, no_env = no_env); + +test_case!( + host_ipv6, + "postgres://?host=::1", + output = {"host": "::1"}, + no_env = no_env +); + +test_case!(port_only_1, "postgres://:1", output = {"port": "1"}, no_env = no_env); + +test_case!(port_only_2, "postgres://:1,:2", output = {"host": ",", "port": "1,2"}, no_env = no_env); + +test_case!(port_host_mix, "postgres://:1,host2:2,:3", output = {"port": "1,2,3", "host": ",host2,"}, no_env = no_env); + +test_case!(db_override_1, "postgres:///db?dbname=db2", output={ + "dbname": "db2", +}, no_env=no_env); + +test_case!(db_override_2, "postgres:///?dbname=db2", output={ + "dbname": "db2", +}, no_env=no_env); + +test_case!(db_override_3, "postgres://?dbname=db3", output={ + "dbname": "db3", +}, no_env=no_env); + +test_case!(empty_host, "postgres://user@/?host=,", output = { + "host": "/var/run/postgresql,/var/run/postgresql", + "user": "user", + "dbname": "user", + "port": "5432,5432", +}); + +test_case!(empty_param, "postgres://user@old_host:1234?host=&port=", output={ + "port": "5432", + "host": "/var/run/postgresql", + "user": "user", + "dbname": "user", +}); + +test_case!( + hosts_in_host_param, + "postgres://user@/dbname?host=[::1]", + error = "Invalid DSN.*", + expect_libpq_mismatch = "libpq allows for these invalid hostnames" +); + +test_case!( + non_ipv6_in_brackets, + 
"postgres://user@[localhost]/dbname", + output={ + "user": "user", + "host": "localhost", + "port": "5432", + "dbname": "dbname", + } +); + +test_case!( + path_in_host, + "postgres://user@%2ffoo/dbname", + output={ + "user": "user", + "host": "/foo", + "port": "5432", + "dbname": "dbname", + } +); + +test_case!( + path_in_host_2, + "postgres://user@[/foo]/dbname", + output={ + "user": "user", + "host": "/foo", + "port": "5432", + "dbname": "dbname", + } +); + +test_case!( + path_in_host_3, + "postgres://user@[/foo],[/bar]/dbname", + output={ + "user": "user", + "host": "/foo,/bar", + "port": "5432,5432", + "dbname": "dbname", + } +); + +test_case!( + only_one_part_user, + "postgres://%E3%83%A6%E3%83%BC%E3%82%B6%E3%83%BC%E5%90%8D@", + output={ + "user": "ユーザー名" + }, + no_env=no_env +); + +test_case!( + only_one_part_pass, + "postgres://:%E3%83%91%E3%82%B9%E3%83%AF%E3%83%BC%E3%83%89@", + output={ + "password": "パスワード" + }, + no_env=no_env +); + +test_case!( + only_one_part_port, + "postgres://:1234", + output={ + "port": "1234" + }, + no_env=no_env +); + +test_case!( + only_one_part_hostname, + "postgres://%E3%83%9B%E3%82%B9%E3%83%88%E5%90%8D", + output={ + "host": "ホスト名" + }, + no_env=no_env +); + +test_case!( + only_one_part_database, + "postgres:///%E3%83%87%E3%83%BC%E3%82%BF%E3%83%99%E3%83%BC%E3%82%B9", + output={ + "dbname": "データベース" + }, + no_env=no_env +); diff --git a/edb/server/pgrust/tests/libpq_test_cases.rs b/edb/server/pgrust/tests/libpq_test_cases.rs new file mode 100644 index 000000000000..cf5cebbce767 --- /dev/null +++ b/edb/server/pgrust/tests/libpq_test_cases.rs @@ -0,0 +1,329 @@ +// Copyright (c) 2021-2023, PostgreSQL Global Development Group + +// These testcases were extracted from 001_uri.pl. 
+ +mod test_util; + +test_case!(full_uri, "postgresql://uri-user:secret@host:12345/db", output={ + "user": "uri-user", + "password": "secret", + "dbname": "db", + "host": "host", + "port": "12345" +}, no_env=no_env); + +test_case!(user_host_port_db, "postgresql://uri-user@host:12345/db", output={ + "user": "uri-user", + "dbname": "db", + "host": "host", + "port": "12345" +}, no_env=no_env); + +test_case!(user_host_db, "postgresql://uri-user@host/db", output={ + "user": "uri-user", + "dbname": "db", + "host": "host" +}, no_env=no_env); + +test_case!(host_port_db, "postgresql://host:12345/db", output={ + "dbname": "db", + "host": "host", + "port": "12345" +}, no_env=no_env); + +test_case!(host_db, "postgresql://host/db", output={ + "dbname": "db", + "host": "host" +}, no_env=no_env); + +test_case!(user_host_port, "postgresql://uri-user@host:12345/", output={ + "user": "uri-user", + "host": "host", + "port": "12345" +}, no_env=no_env); + +test_case!(user_host, "postgresql://uri-user@host/", output={ + "user": "uri-user", + "host": "host" +}, no_env=no_env); + +test_case!(user_only, "postgresql://uri-user@", output={ + "user": "uri-user" +}, no_env=no_env); + +test_case!(host_port, "postgresql://host:12345/", output={ + "host": "host", + "port": "12345" +}, no_env=no_env); + +test_case!(host_port_no_slash, "postgresql://host:12345", output={ + "host": "host", + "port": "12345" +}, no_env=no_env); + +test_case!(host_only, "postgresql://host/", output={ + "host": "host" +}, no_env=no_env); + +test_case!(host_no_slash, "postgresql://host", output={ + "host": "host" +}, no_env=no_env); + +test_case!(empty_uri, "postgresql://", output = {}, no_env = no_env); + +test_case!(hostaddr_only, "postgresql://?hostaddr=127.0.0.1", output={ + "hostaddr": "127.0.0.1" +}, no_env=no_env); + +test_case!(host_and_hostaddr, "postgresql://example.com?hostaddr=63.1.2.4", output={ + "host": "example.com", + "hostaddr": "63.1.2.4" +}, no_env=no_env); + +test_case!(percent_encoded_host, 
"postgresql://%68ost/", output={ + "host": "host" +}, no_env=no_env); + +test_case!(query_user, "postgresql://host/db?user=uri-user", output={ + "user": "uri-user", + "dbname": "db", + "host": "host" +}, no_env=no_env); + +test_case!(query_user_port, "postgresql://host/db?user=uri-user&port=12345", output={ + "user": "uri-user", + "dbname": "db", + "host": "host", + "port": "12345" +}, no_env=no_env); + +test_case!(query_percent_encoded_user, "postgresql://host/db?u%73er=someotheruser&port=12345", output={ + "user": "someotheruser", + "dbname": "db", + "host": "host", + "port": "12345" +}, no_env=no_env); + +test_case!( + invalid_percent_encoded_uzer, + "postgresql://host/db?u%7aer=someotheruser&port=12345", + output = {"uzer": "someotheruser", "host": "host", "dbname": "db", "port": "12345"}, + expect_libpq_mismatch = "Our library allows arbitrary params", + no_env = no_env +); + +test_case!(query_user_with_port, "postgresql://host:12345?user=uri-user", output={ + "user": "uri-user", + "host": "host", + "port": "12345" +}, no_env=no_env); + +test_case!(query_user_with_host, "postgresql://host?user=uri-user", output={ + "user": "uri-user", + "host": "host" +}, no_env=no_env); + +test_case!(empty_query, "postgresql://host?", output={ + "host": "host" +}, no_env=no_env); + +test_case!(ipv6_host_port_db, "postgresql://[::1]:12345/db", output={ + "dbname": "db", + "host": "::1", + "port": "12345" +}, no_env=no_env); + +test_case!(ipv6_host_db, "postgresql://[::1]/db", output={ + "dbname": "db", + "host": "::1" +}, no_env=no_env); + +test_case!(ipv6_host_full, "postgresql://[2001:db8::1234]/", output={ + "host": "2001:db8::1234" +}, no_env=no_env); + +test_case!( + invalid_ipv6_host, + "postgresql://[200z:db8::1234]/", + error = "", + expect_libpq_mismatch = "Invalid hosts are caught early", + no_env = no_env +); + +test_case!(ipv6_host_only, "postgresql://[::1]", output={ + "host": "::1" +}, no_env=no_env); + +test_case!(postgres_empty, "postgres://", output = {}, 
no_env = no_env); + +test_case!(postgres_root, "postgres:///", output = {}, no_env = no_env); + +test_case!(postgres_db_only, "postgres:///db", output={ + "dbname": "db" +}, no_env=no_env); + +test_case!(postgres_user_db, "postgres://uri-user@/db", output={ + "user": "uri-user", + "dbname": "db" +}, no_env=no_env); + +test_case!(postgres_socket_dir, "postgres://?host=/path/to/socket/dir", output={ + "host": "/path/to/socket/dir" +}, no_env=no_env); + +test_case!( + invalid_query_param, + "postgresql://host?uzer=", + output = { + "uzer": "", + "host": "host", + }, + expect_libpq_mismatch = "Arbitrary query params are supported", + no_env = no_env +); + +test_case!( + invalid_scheme, + "postgre://", + error = "missing \"=\" after \"postgre://\" in connection info string", + no_env = no_env +); + +test_case!(unclosed_ipv6, "postgres://[::1", error="end of string reached when looking for matching \"]\" in IPv6 host address in URI: \"postgres://[::1", no_env=no_env); + +test_case!( + empty_ipv6, + "postgres://[]", + error = "IPv6 host address may not be empty in URI: \"postgres://[]\"", + no_env = no_env +); + +test_case!(invalid_ipv6_end, "postgres://[::1]z", error="unexpected character \"z\" at position 17 in URI (expected \":\" or \"/\"): \"postgres://[::1]z\"", no_env=no_env); + +test_case!( + missing_query_value, + "postgresql://host?zzz", + error = "missing key/value separator \"=\" in URI query parameter: \"zzz\"", + no_env = no_env +); + +test_case!( + multiple_missing_values, + "postgresql://host?value1&value2", + error = "missing key/value separator \"=\" in URI query parameter: \"value1\"", + no_env = no_env +); + +test_case!( + extra_equals, + "postgresql://host?key=key=value", + error = "", + no_env = no_env +); + +test_case!( + invalid_percent_encoding, + "postgres://host?dbname=%XXfoo", + error = "invalid percent-encoded token: \"%XXfoo\"", + no_env = no_env +); + +test_case!( + null_in_percent_encoding, + "postgresql://a%00b", + error = "forbidden value 
%00 in percent-encoded value: \"a%00b\"", + no_env = no_env +); + +test_case!( + invalid_percent_encoding_zz, + "postgresql://%zz", + error = "invalid percent-encoded token: \"%zz\"", + no_env = no_env +); + +test_case!( + incomplete_percent_encoding_1, + "postgresql://%1", + error = "invalid percent-encoded token: \"%1\"", + no_env = no_env +); + +test_case!( + incomplete_percent_encoding_empty, + "postgresql://%", + error = "invalid percent-encoded token: \"%\"", + no_env = no_env +); + +test_case!(empty_user, "postgres://@host", output={ + "host": "host" +}, no_env=no_env); + +test_case!(empty_port, "postgres://host:/", output={ + "host": "host" +}, no_env=no_env); + +test_case!(port_only, "postgres://:12345/", output={ + "port": "12345" +}, no_env=no_env); + +test_case!(user_query_host, "postgres://otheruser@?host=/no/such/directory", output={ + "user": "otheruser", + "host": "/no/such/directory" +}, no_env=no_env); + +test_case!(user_query_host_with_slash, "postgres://otheruser@/?host=/no/such/directory", output={ + "user": "otheruser", + "host": "/no/such/directory" +}, no_env=no_env); + +test_case!(user_port_query_host, "postgres://otheruser@:12345?host=/no/such/socket/path", output={ + "user": "otheruser", + "host": "/no/such/socket/path", + "port": "12345" +}, no_env=no_env); + +test_case!(user_port_db_query_host, "postgres://otheruser@:12345/db?host=/path/to/socket", output={ + "user": "otheruser", + "dbname": "db", + "host": "/path/to/socket", + "port": "12345" +}, no_env=no_env); + +test_case!(port_db_query_host, "postgres://:12345/db?host=/path/to/socket", output={ + "dbname": "db", + "host": "/path/to/socket", + "port": "12345" +}, no_env=no_env); + +test_case!(port_query_host, "postgres://:12345?host=/path/to/socket", output={ + "host": "/path/to/socket", + "port": "12345" +}, no_env=no_env); + +test_case!(percent_encoded_path, "postgres://%2Fvar%2Flib%2Fpostgresql/dbname", output={ + "dbname": "dbname", + "host": "/var/lib/postgresql" +}, 
no_env=no_env); + +test_case!(sslmode_disable, "postgresql://host?sslmode=disable", output={ + "host": "host", + "sslmode": "disable" +}, no_env=no_env); + +// This one is challenging to test because of the sslmode defaults +// test_case!(sslmode_prefer, "postgresql://host?sslmode=prefer", output={ +// "host": "host", +// "sslmode": "prefer" +// }, no_env=no_env); + +// Intentional difference from libpq: this is what they do (from 001_uri.pl): +// +// "Usually the default sslmode is 'prefer' (for libraries with SSL) or +// 'disable' (for those without). This default changes to 'verify-full' if +// the system CA store is in use." +test_case!(sslmode_verify_full, "postgresql://host?sslmode=verify-full", output={ + "host": "host", + "sslmode": "verify-full" +}, no_env=no_env); diff --git a/edb/server/pgrust/tests/real_postgres.rs b/edb/server/pgrust/tests/real_postgres.rs new file mode 100644 index 000000000000..db84f6cc72d9 --- /dev/null +++ b/edb/server/pgrust/tests/real_postgres.rs @@ -0,0 +1,375 @@ +// Constants +use openssl::ssl::{Ssl, SslContext, SslMethod}; +use pgrust::connection::tokio::TokioSocketAddress; +use pgrust::connection::{connect_raw_ssl, Authentication, ConnectionSslRequirement, Credentials}; +use rstest::rstest; +use std::io::{BufRead, BufReader, Write}; +use std::net::{Ipv4Addr, TcpListener}; +use std::os::unix::fs::PermissionsExt; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; +use std::sync::{Arc, RwLock}; +use std::thread; +use std::time::{Duration, Instant}; +use tempfile::TempDir; + +const STARTUP_TIMEOUT_DURATION: Duration = Duration::from_secs(30); +const PORT_RELEASE_TIMEOUT: Duration = Duration::from_secs(30); +const LINGER_DURATION: Duration = Duration::from_secs(1); +const HOT_LOOP_INTERVAL: Duration = Duration::from_millis(100); +const DEFAULT_USERNAME: &str = "username"; +const DEFAULT_PASSWORD: &str = "password"; +const DEFAULT_DATABASE: &str = "postgres"; + +/// Represents an ephemeral port that can be 
allocated and released for immediate re-use by another process. +struct EphemeralPort { + port: u16, + listener: Option, +} + +impl EphemeralPort { + /// Allocates a new ephemeral port. + /// + /// Returns a Result containing the EphemeralPort if successful, + /// or an IO error if the allocation fails. + fn allocate() -> std::io::Result { + let socket = socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + socket.set_linger(Some(LINGER_DURATION))?; + socket.bind(&std::net::SocketAddr::from((Ipv4Addr::LOCALHOST, 0)).into())?; + socket.listen(1)?; + let listener = TcpListener::from(socket); + let port = listener.local_addr()?.port(); + Ok(EphemeralPort { + port, + listener: Some(listener), + }) + } + + /// Consumes the EphemeralPort and returns the allocated port number. + fn take(self) -> u16 { + // Drop the listener to free up the port + drop(self.listener); + + // Loop until the port is free + let start = Instant::now(); + + // If we can successfully connect to the port, it's not fully closed + while start.elapsed() < PORT_RELEASE_TIMEOUT { + let res = std::net::TcpStream::connect((Ipv4Addr::LOCALHOST, self.port)); + if res.is_err() { + // If connection fails, the port is released + break; + } + std::thread::sleep(HOT_LOOP_INTERVAL); + } + + self.port + } +} + +struct StdioReader { + output: Arc>, +} + +impl StdioReader { + fn spawn(reader: R, prefix: &'static str) -> Self { + let output = Arc::new(RwLock::new(String::new())); + let output_clone = Arc::clone(&output); + + thread::spawn(move || { + let mut buf_reader = std::io::BufReader::new(reader); + loop { + let mut line = String::new(); + match buf_reader.read_line(&mut line) { + Ok(0) => break, + Ok(_) => { + if let Ok(mut output) = output_clone.write() { + output.push_str(&line); + } + eprint!("[{}]: {}", prefix, line); + } + Err(e) => { + let error_line = format!("Error reading {}: {}\n", prefix, e); + if let Ok(mut 
output) = output_clone.write() { + output.push_str(&error_line); + } + eprintln!("{}", error_line); + } + } + } + }); + + StdioReader { output } + } + + fn contains(&self, s: &str) -> bool { + if let Ok(output) = self.output.read() { + output.contains(s) + } else { + false + } + } +} + +fn init_postgres(initdb: &Path, data_dir: &Path, auth: Authentication) -> std::io::Result<()> { + let mut pwfile = tempfile::NamedTempFile::new()?; + writeln!(pwfile, "{}", DEFAULT_PASSWORD)?; + let mut command = Command::new(initdb); + command + .arg("-D") + .arg(data_dir) + .arg("-A") + .arg(match auth { + Authentication::None => "trust", + Authentication::Password => "password", + Authentication::Md5 => "md5", + Authentication::ScramSha256 => "scram-sha-256", + }) + .arg("--pwfile") + .arg(pwfile.path()) + .arg("-U") + .arg(DEFAULT_USERNAME); + + let output = command.output()?; + + let status = output.status; + let output_str = String::from_utf8_lossy(&output.stdout).to_string(); + let error_str = String::from_utf8_lossy(&output.stderr).to_string(); + + eprintln!("initdb stdout:\n{}", output_str); + eprintln!("initdb stderr:\n{}", error_str); + + if !status.success() { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "initdb command failed", + )); + } + + Ok(()) +} + +fn run_postgres( + postgres_bin: &Path, + data_dir: &Path, + socket_path: &Path, + ssl: Option<(PathBuf, PathBuf)>, + port: u16, +) -> std::io::Result { + let mut command = Command::new(postgres_bin); + command + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .arg("-D") + .arg(data_dir) + .arg("-k") + .arg(socket_path) + .arg("-h") + .arg(Ipv4Addr::LOCALHOST.to_string()) + .arg("-F") + // Useful for debugging + // .arg("-d") + // .arg("5") + .arg("-p") + .arg(port.to_string()); + + if let Some((cert_path, key_path)) = ssl { + let postgres_cert_path = data_dir.join("server.crt"); + let postgres_key_path = data_dir.join("server.key"); + std::fs::copy(cert_path, &postgres_cert_path)?; + 
std::fs::copy(key_path, &postgres_key_path)?; + // Set permissions for the certificate and key files + std::fs::set_permissions(&postgres_cert_path, std::fs::Permissions::from_mode(0o600))?; + std::fs::set_permissions(&postgres_key_path, std::fs::Permissions::from_mode(0o600))?; + + // Edit pg_hba.conf to change all "host" line prefixes to "hostssl" + let pg_hba_path = data_dir.join("pg_hba.conf"); + let content = std::fs::read_to_string(&pg_hba_path)?; + let modified_content = content + .lines() + .map(|line| { + if line.trim_start().starts_with("host") { + line.replacen("host", "hostssl", 1) + } else { + line.to_string() + } + }) + .collect::>() + .join("\n"); + eprintln!("pg_hba.conf:\n{modified_content}"); + std::fs::write(&pg_hba_path, modified_content)?; + + command.arg("-l"); + } + + let mut child = command.spawn()?; + + let stdout_reader = BufReader::new(child.stdout.take().expect("Failed to capture stdout")); + let _ = StdioReader::spawn(stdout_reader, "stdout"); + let stderr_reader = BufReader::new(child.stderr.take().expect("Failed to capture stderr")); + let stderr_reader = StdioReader::spawn(stderr_reader, "stderr"); + + let start_time = Instant::now(); + + let mut tcp_socket: Option = None; + let mut unix_socket: Option = None; + + let unix_socket_path = get_unix_socket_path(socket_path, port); + let tcp_socket_addr = std::net::SocketAddr::from((Ipv4Addr::LOCALHOST, port)); + let mut db_ready = false; + + while start_time.elapsed() < STARTUP_TIMEOUT_DURATION { + std::thread::sleep(HOT_LOOP_INTERVAL); + match child.try_wait() { + Ok(Some(status)) => { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + format!("PostgreSQL exited with status: {}", status), + )) + } + Err(e) => return Err(e), + _ => {} + } + if !db_ready && stderr_reader.contains("database system is ready to accept connections") { + eprintln!("Database is ready"); + db_ready = true; + } else { + continue; + } + if unix_socket.is_none() { + unix_socket = 
std::os::unix::net::UnixStream::connect(&unix_socket_path).ok(); + } + if tcp_socket.is_none() { + tcp_socket = std::net::TcpStream::connect(tcp_socket_addr).ok(); + } + if unix_socket.is_some() && tcp_socket.is_some() { + break; + } + } + + if unix_socket.is_some() && tcp_socket.is_some() { + return Ok(child); + } + + // Print status for TCP/unix sockets + if let Some(tcp) = &tcp_socket { + eprintln!( + "TCP socket at {tcp_socket_addr:?} bound successfully on {}", + tcp.local_addr()? + ); + } else { + eprintln!("TCP socket at {tcp_socket_addr:?} binding failed"); + } + + if unix_socket.is_some() { + eprintln!("Unix socket at {unix_socket_path:?} connected successfully"); + } else { + eprintln!("Unix socket at {unix_socket_path:?} connection failed"); + } + + Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "PostgreSQL failed to start within 30 seconds", + )) +} + +fn test_data_dir() -> std::path::PathBuf { + Path::new("../../../tests") + .canonicalize() + .expect("Failed to canonicalize tests directory path") +} + +fn postgres_bin_dir() -> std::path::PathBuf { + Path::new("../../../build/postgres/install/bin") + .canonicalize() + .expect("Failed to canonicalize postgres bin directory path") +} + +fn get_unix_socket_path(socket_path: &Path, port: u16) -> PathBuf { + socket_path.join(format!(".s.PGSQL.{}", port)) +} + +#[derive(Debug, Clone, Copy)] +enum Mode { + Tcp, + TcpSsl, + Unix, +} + +#[rstest] +#[tokio::test] +async fn test_auth_real( + #[values( + Authentication::None, + Authentication::Password, + Authentication::Md5, + Authentication::ScramSha256 + )] + auth: Authentication, + #[values(Mode::Tcp, Mode::TcpSsl, Mode::Unix)] mode: Mode, +) -> Result<(), Box> { + let initdb = postgres_bin_dir().join("initdb"); + let postgres = postgres_bin_dir().join("postgres"); + + if !initdb.exists() || !postgres.exists() { + println!("Skipping test: initdb or postgres not found"); + return Ok(()); + } + + let port = EphemeralPort::allocate()?; + let temp_dir = 
TempDir::new()?; + let data_dir = temp_dir.path().join("data"); + + init_postgres(&initdb, &data_dir, auth)?; + let ssl_key = match mode { + Mode::TcpSsl => { + let certs_dir = test_data_dir().join("certs"); + let cert = certs_dir.join("server.cert.pem"); + let key = certs_dir.join("server.key.pem"); + Some((cert, key)) + } + _ => None, + }; + + let port = port.take(); + let mut child = run_postgres(&postgres, &data_dir, &data_dir, ssl_key, port)?; + + let credentials = Credentials { + username: DEFAULT_USERNAME.to_string(), + password: DEFAULT_PASSWORD.to_string(), + database: DEFAULT_DATABASE.to_string(), + server_settings: Default::default(), + }; + let ssl = SslContext::builder(SslMethod::tls_client())?.build(); + let mut ssl = Ssl::new(&ssl)?; + ssl.set_connect_state(); + + let socket_address = match mode { + Mode::Unix => { + let path = get_unix_socket_path(&data_dir, port); + TokioSocketAddress::new_unix(path) + } + Mode::Tcp | Mode::TcpSsl => TokioSocketAddress::new_tcp((Ipv4Addr::LOCALHOST, port).into()), + }; + + let client = socket_address.connect().await?; + + let ssl_requirement = match mode { + Mode::TcpSsl => ConnectionSslRequirement::Required, + _ => ConnectionSslRequirement::Optional, + }; + + let params = connect_raw_ssl(credentials, ssl_requirement, ssl, client) + .await? 
+ .params() + .clone(); + + assert_eq!(matches!(mode, Mode::TcpSsl), params.ssl); + assert_eq!(auth, params.auth); + + child.kill()?; + + Ok(()) +} diff --git a/edb/server/pgrust/tests/test_util/dsn_libpq.rs b/edb/server/pgrust/tests/test_util/dsn_libpq.rs new file mode 100644 index 000000000000..df7909c76db1 --- /dev/null +++ b/edb/server/pgrust/tests/test_util/dsn_libpq.rs @@ -0,0 +1,256 @@ +use std::collections::HashMap; +use std::ffi::{CStr, CString}; +use std::ptr; + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct _PQconninfoOption { + /// The keyword of the option + keyword: *mut libc::c_char, + /// Fallback environment variable name + envvar: *mut libc::c_char, + /// Fallback compiled in default value + compiled: *mut libc::c_char, + /// Option's current value, or NULL + val: *mut libc::c_char, + /// Label for field in connect dialog + label: *mut libc::c_char, + /// Indicates how to display this field in a connect dialog. + /// Values are: + /// - "": Display entered value as is + /// - "*": Password field - hide value + /// - "D": Debug option - don't show by default + dispchar: *mut libc::c_char, + /// Field size in characters for dialog + dispsize: libc::c_int, +} + +/// Rust-friendly version of PQconninfoOption +#[derive(Debug, Clone)] +#[allow(unused)] +pub struct PQConnInfoOption { + /// The keyword of the option + pub keyword: Option, + /// Fallback environment variable name + pub envvar: Option, + /// Fallback compiled in default value + pub compiled: Option, + /// Option's current value, or None + pub val: Option, + /// Label for field in connect dialog + pub label: Option, + /// Indicates how to display this field in a connect dialog. 
+ /// Values are: + /// - "": Display entered value as is + /// - "*": Password field - hide value + /// - "D": Debug option - don't show by default + pub dispchar: Option, + /// Field size in characters for dialog + pub dispsize: i32, +} + +impl From<&_PQconninfoOption> for PQConnInfoOption { + fn from(option: &_PQconninfoOption) -> Self { + unsafe { + PQConnInfoOption { + keyword: (!option.keyword.is_null()).then(|| { + CStr::from_ptr(option.keyword) + .to_string_lossy() + .into_owned() + }), + envvar: (!option.envvar.is_null()) + .then(|| CStr::from_ptr(option.envvar).to_string_lossy().into_owned()), + compiled: (!option.compiled.is_null()).then(|| { + CStr::from_ptr(option.compiled) + .to_string_lossy() + .into_owned() + }), + val: (!option.val.is_null()) + .then(|| CStr::from_ptr(option.val).to_string_lossy().into_owned()), + label: (!option.label.is_null()) + .then(|| CStr::from_ptr(option.label).to_string_lossy().into_owned()), + dispchar: (!option.dispchar.is_null()).then(|| { + CStr::from_ptr(option.dispchar) + .to_string_lossy() + .into_owned() + }), + dispsize: option.dispsize, + } + } + } +} + +#[link(name = "pq")] +extern "C" { + /// Parses a connection string and returns the resulting connection options. + /// + /// This function parses a string in the same way as `PQconnectdb()` would, + /// and returns an array of connection options. If parsing fails, it returns NULL. + /// The returned options only include those explicitly specified in the string, + /// not any default values. + /// + /// # Arguments + /// + /// * `conninfo` - The connection string to parse. + /// * `errmsg` - If not NULL, it will be set to NULL on success, or a malloc'd + /// error message string on failure. The caller should free this string + /// using `PQfreemem()` when it's no longer needed. + /// + /// # Returns + /// + /// A pointer to a dynamically allocated array of `PQconninfoOption` structures, + /// or NULL on failure. 
In out-of-memory conditions, both `*errmsg` and the + /// return value could be NULL. + /// + /// # Safety + /// + /// The returned array should be freed using `PQconninfoFree()` when no longer needed. + fn PQconninfoParse( + conninfo: *const libc::c_char, + errmsg: *mut *mut libc::c_char, + ) -> *mut _PQconninfoOption; + + /// Constructs a default connection options array. + /// + /// This function identifies all available options and shows any default values + /// that are available from the environment, etc. On error (e.g., out of memory), + /// NULL is returned. + /// + /// Using this function, an application may determine all possible options + /// and their current default values. + /// + /// # Returns + /// + /// A pointer to a dynamically allocated array of `PQconninfoOption` structures, + /// or NULL on failure. + /// + /// # Safety + /// + /// The returned array should be freed using `PQconninfoFree()` when no longer needed. + /// + /// # Note + /// + /// As of PostgreSQL 7.0, the returned array is dynamically allocated. + /// Pre-7.0 applications that use this function will see a small memory leak + /// until they are updated to call `PQconninfoFree()`. + fn PQconndefaults() -> *mut _PQconninfoOption; + + /// Frees the data structure returned by `PQconndefaults()` or `PQconninfoParse()`. + /// + /// This function should be used to free the memory allocated by `PQconndefaults()` + /// or `PQconninfoParse()` when it's no longer needed. + /// + /// # Arguments + /// + /// * `connOptions` - A pointer to the `PQconninfoOption` structure to be freed. + /// + /// # Safety + /// + /// This function is unsafe because it operates on raw pointers. The caller must + /// ensure that the pointer is valid and points to a structure allocated by + /// `PQconndefaults()` or `PQconninfoParse()`. 
+ fn PQconninfoFree(connOptions: *mut _PQconninfoOption); +} + +fn parse_conninfo_options(options: *mut _PQconninfoOption) -> Vec { + let mut result = Vec::new(); + let mut current = options; + + while !current.is_null() && unsafe { (*current).keyword != ptr::null_mut() } { + let option = unsafe { &*current }; + result.push(PQConnInfoOption::from(option)); + current = unsafe { current.add(1) }; + } + + result +} + +/// Parses a connection string and returns the resulting connection options. +/// +/// # Arguments +/// +/// * `conninfo` - The connection string to parse. +/// +/// # Returns +/// +/// A `Result` containing a `Vec` of `PQConnInfoOption` on success, or an error message on failure. +pub fn pq_conninfo_parse(conninfo: &str) -> Result, String> { + let c_conninfo = CString::new(conninfo).map_err(|e| e.to_string())?; + let mut errmsg: *mut libc::c_char = ptr::null_mut(); + + let options = unsafe { PQconninfoParse(c_conninfo.as_ptr(), &mut errmsg) }; + + if options.is_null() { + let error = if errmsg.is_null() { + "Unknown error occurred during parsing connection info".to_string() + } else { + let error_str = unsafe { CStr::from_ptr(errmsg) } + .to_string_lossy() + .into_owned(); + unsafe { libc::free(errmsg as *mut libc::c_void) }; + error_str + }; + return Err(error); + } + + let result = parse_conninfo_options(options); + unsafe { PQconninfoFree(options) }; + Ok(result) +} + +/// Constructs a default connection options array. +/// +/// # Returns +/// +/// A `Result` containing a `Vec` of `PQConnInfoOption` on success, or an error message on failure. 
+pub fn pq_conn_defaults() -> Result, String> { + let options = unsafe { PQconndefaults() }; + + if options.is_null() { + return Err("Failed to get default connection options".to_string()); + } + + let result = parse_conninfo_options(options); + unsafe { PQconninfoFree(options) }; + Ok(result) +} + +pub fn pq_conn_parse_non_defaults(urn: &str) -> Result, String> { + // Parse the given URN + let parsed_options = pq_conninfo_parse(urn)?; + + // Get the default options + let default_options = pq_conn_defaults()?; + + // Create a HashMap to store non-default values + let mut non_defaults = HashMap::new(); + + // Compare parsed options with defaults and store non-default values + for parsed in parsed_options { + if let Some(keyword) = &parsed.keyword { + if let Some(val) = &parsed.val { + // Find the corresponding default option + if let Some(default) = default_options + .iter() + .find(|&d| d.keyword.as_ref() == Some(keyword)) + { + // If the value is different from the default, add it to non_defaults + if default.val.as_ref() != Some(val) { + non_defaults.insert(keyword.clone(), val.clone()); + } + } else { + // If there's no corresponding default, it's a non-default value + non_defaults.insert(keyword.clone(), val.clone()); + } + } + } + } + + Ok(non_defaults) +} + +#[test] +fn test() { + std::env::set_var("PGUSER", "matt"); + eprintln!("{:#?}", pq_conn_parse_non_defaults("postgres://foo@/")); +} diff --git a/edb/server/pgrust/tests/test_util/mod.rs b/edb/server/pgrust/tests/test_util/mod.rs new file mode 100644 index 000000000000..b6dfb548e44d --- /dev/null +++ b/edb/server/pgrust/tests/test_util/mod.rs @@ -0,0 +1,204 @@ +use pgrust::connection::{ + dsn::parse_postgres_dsn, parse_postgres_dsn_env, RawConnectionParameters, +}; +use std::collections::HashMap; + +mod dsn_libpq; +macro_rules! assert_eq_map { + ($left:expr, $right:expr $(, $($arg:tt)*)?) 
=> {{ + fn make_string(s: impl AsRef) -> String { + s.as_ref().to_string() + } + let left: HashMap<_, _> = $left.clone().into(); + let right: HashMap<_, _> = $right.clone().into(); + let left: std::collections::BTreeMap = left + .into_iter() + .map(|(k, v)| (make_string(k), make_string(v))) + .collect(); + let right: std::collections::BTreeMap = right + .into_iter() + .map(|(k, v)| (make_string(k), make_string(v))) + .collect(); + + pretty_assertions::assert_eq!(left, right $(, $($arg)*)?); + }}; +} + +#[track_caller] +pub(crate) fn test( + dsn: &str, + expected: HashMap, + env: HashMap, + expect_mismatch: bool, + no_env: bool, +) { + eprintln!("DSN: {dsn:?}"); + + let mut ours_no_env = match parse_postgres_dsn(dsn) { + Err(res) => panic!("Expected test to pass {dsn:?}, but instead failed:\n{res:#?}"), + Ok(res) => res, + }; + + eprintln!("Parsed: {ours_no_env:#?}"); + let ours: HashMap = ours_no_env.clone().into(); + eprintln!("Parsed (map): {ours:#?}"); + + let url = ours_no_env.to_url(); + let roundtrip = match parse_postgres_dsn(&url) { + Err(res) => { + panic!("Expected roundtripped URL to pass {url:?}, but instead failed:\n{res:#?}") + } + Ok(res) => res, + }; + assert_eq_map!( + roundtrip, + ours_no_env, + "Did not maintain fidelity through the roundtrip! ({url:?})" + ); + + if no_env { + assert_eq_map!( + expected, + ours_no_env, + "crate mismatch from expected when parsing {dsn:?}" + ); + } else { + let ours = match parse_postgres_dsn_env(dsn, env) { + Err(res) => panic!("Expected test to pass {dsn:?}, but instead failed:\n{res:#?}"), + Ok(res) => res, + }; + + // Avoid the hassle of specifying the default SSL mode unless explicitly tested for. 
+ let mut ours: HashMap = RawConnectionParameters::from(ours).into(); + if !expected.contains_key("sslmode") { + ours.remove("sslmode"); + } + + assert_eq_map!( + expected, + ours, + "crate mismatch from expected when parsing {dsn:?}" + ); + } + + let res = dsn_libpq::pq_conn_parse_non_defaults(dsn); + eprintln!("{res:?}"); + if expect_mismatch { + assert!(res.is_err()); + } else { + let libpq = match res { + Err(res) => panic!("Expected test to pass {dsn:?}, but instead failed:\n{res:#?}"), + Ok(res) => res, + }; + + // Only compare for no_env + if no_env { + assert_eq_map!( + libpq, + expected, + "libpq mismatch from expected when parsing {dsn:?}" + ); + } else { + // We cannot detect libpq's defaults here so we just remove them + // from the test + if ours_no_env.port == Some(vec![Some(5432)]) { + ours_no_env.port = None + } + assert_eq_map!( + libpq, + ours_no_env, + "libpq mismatch from expected when parsing {dsn:?}" + ); + } + } +} + +#[track_caller] +pub(crate) fn test_fail( + dsn: &str, + env: HashMap, + expect_mismatch: bool, + no_env: bool, +) { + let res = dsn_libpq::pq_conn_parse_non_defaults(dsn); + eprintln!("libpq: {res:#?}"); + if expect_mismatch { + assert!(res.is_ok()); + } else if let Ok(res) = res { + panic!("Expected test to fail {dsn:?}, but instead parsed correctly:\n{res:#?}") + } + if no_env { + match parse_postgres_dsn(dsn) { + Ok(res) => { + panic!("Expected test to fail {dsn:?}, but instead parsed correctly:\n{res:#?}") + } + Err(e) => { + eprintln!("Error: {e:#?}") + } + } + } else { + match parse_postgres_dsn_env(dsn, env) { + Ok(res) => { + panic!("Expected test to fail {dsn:?}, but instead parsed correctly:\n{res:#?}") + } + Err(e) => { + eprintln!("Error: {e:#?}") + } + } + } +} + +#[macro_export] +macro_rules! env { + ({ $($key:literal : $value:expr),* $(,)? 
}) => {{ + #[allow(unused_mut)] + let mut map = std::collections::HashMap::new(); + $( + map.insert($key.to_string(), $value.to_string()); + )* + map + }}; + () => { + std::collections::HashMap::new() + }; +} +pub use env; + +#[macro_export] +macro_rules! test_case { + ($name:ident, $urn:literal, output=$output:tt $( , expect_libpq_mismatch=$reason:literal )? $( , no_env=$no_env:ident )?) => { + paste::paste!( #[test] fn [< test_ $name >]() { + let expect_libpq_mismatch: &[&'static str] = &[$($reason)?]; + let no_env: &[&'static str] = &[$(stringify!($no_env))?]; + $crate::test_util::test($urn, $crate::test_util::env!($output), $crate::test_util::env!({}), expect_libpq_mismatch.len() > 0, no_env.len() > 0) + } ); + }; + ($name:ident, $urn:literal, output=$output:tt, extra=$extra:tt $( , expect_libpq_mismatch=$reason:literal )? $( , no_env=$no_env:ident )?) => { + paste::paste!( #[test] fn [< test_ $name >]() { + let expect_libpq_mismatch: &[&'static str] = &[$($reason)?]; + let no_env: &[&'static str] = &[$(stringify!($no_env))?]; + $crate::test_util::test($urn, $crate::test_util::env!($output),$crate::test_util::env!({}), expect_libpq_mismatch.len() > 0, no_env.len() > 0) + } ); + }; + ($name:ident, $urn:literal, error=$error:tt $( , expect_libpq_mismatch=$reason:literal )? $( , no_env=$no_env:ident )?) => { + paste::paste!( #[test] fn [< test_ $name >]() { + let expect_libpq_mismatch: &[&'static str] = &[$($reason)?]; + let no_env: &[&'static str] = &[$(stringify!($no_env))?]; + $crate::test_util::test_fail($urn, $crate::test_util::env!({}), expect_libpq_mismatch.len() > 0, no_env.len() > 0) + } ); + }; + ($name:ident, $urn:literal, env=$env:tt, output=$output:tt $( , expect_libpq_mismatch=$reason:literal )? $( , no_env=$no_env:ident )?) 
=> { + paste::paste!( #[test] fn [< test_ $name >]() { + let expect_libpq_mismatch: &[&'static str] = &[$($reason)?]; + let no_env: &[&'static str] = &[$(stringify!($no_env))?]; + $crate::test_util::test($urn, $crate::test_util::env!($output), $crate::test_util::env!($env), expect_libpq_mismatch.len() > 0, no_env.len() > 0) + } ); + }; + ($name:ident, $urn:literal, env=$env:tt, error=$output:tt $( , expect_libpq_mismatch=$reason:literal )? $( , no_env=$no_env:ident )?) => { + paste::paste!( #[test] fn [< test_ $name >]() { + let expect_libpq_mismatch: &[&'static str] = &[$($reason)?]; + let no_env: &[&'static str] = &[$(stringify!($no_env))?]; + $crate::test_util::test_fail($urn, $crate::test_util::env!($env), expect_libpq_mismatch.len() > 0, no_env.len() > 0) + } ); + }; +} diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index dec7613c1bdd..c4a32d36f477 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -320,17 +320,9 @@ cdef class EdgeConnection(frontend.FrontendConnection): self.write(buf) + # In dev mode we expose the backend postgres DSN if self.server.in_dev_mode(): - pgaddr = dict(self.tenant.get_pgaddr()) - if pgaddr.get('password'): - pgaddr['password'] = '********' - pgaddr['database'] = self.tenant.get_pg_dbname( - self.get_dbview().dbname - ) - pgaddr.pop('ssl', None) - if 'sslmode' in pgaddr: - pgaddr['sslmode'] = pgaddr['sslmode'].name - self.write_status(b'pgaddr', json.dumps(pgaddr).encode()) + self.write_status(b'pgdsn', self.tenant.get_pgaddr().encode()) self.write_status( b'suggested_pool_concurrency', diff --git a/edb/server/render_dsn.py b/edb/server/render_dsn.py deleted file mode 100644 index 4963f127e235..000000000000 --- a/edb/server/render_dsn.py +++ /dev/null @@ -1,57 +0,0 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from __future__ import annotations - -import urllib.parse - - -def render_dsn(scheme, params): - params = dict(params) - dsn = params.pop('dsn', '') - if dsn: - return dsn - - user = params.pop('user', '') - if user: - password = params.pop('password', '') - if password: - user += f':{password}' - - if user: - user += '@' - - host = params.pop('host', 'localhost') - if '/' in host: - # Put host back, it's a UNIX socket path, needs to be - # in query part. - params['host'] = host - host = '' - port = '' - else: - port = params.pop('port') - if port: - port = f':{port}' - - if params: - query = '?' 
+ urllib.parse.urlencode(params) - else: - query = '' - - return f'{scheme}://{user}{host}{port}{query}' diff --git a/edb/server/tenant.py b/edb/server/tenant.py index e2a67fce465f..09269b48cb05 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -26,7 +26,6 @@ Mapping, Coroutine, AsyncGenerator, - Dict, Set, Optional, TypedDict, @@ -277,8 +276,8 @@ def suggested_client_pool_size(self) -> int: def get_pg_dbname(self, dbname: str) -> str: return self._cluster.get_db_name(dbname) - def get_pgaddr(self) -> Dict[str, Any]: - return self._cluster.get_connection_spec() + def get_pgaddr(self) -> str: + return self._cluster.get_pgaddr() @functools.lru_cache def get_backend_runtime_params(self) -> pgparams.BackendRuntimeParams: @@ -350,7 +349,10 @@ async def _fetch_roles(self, syscon: pgcon.PGConnection) -> None: async def init_sys_pgcon(self) -> None: self._sys_pgcon_waiter = asyncio.Lock() - self.__sys_pgcon = await self._pg_connect(defines.EDGEDB_SYSTEM_DB) + self.__sys_pgcon = await self._pg_connect( + defines.EDGEDB_SYSTEM_DB, + source_description="init_sys_pgcon", + ) self._sys_pgcon_ready_evt = asyncio.Event() self._sys_pgcon_reconnect_evt = asyncio.Event() @@ -537,7 +539,11 @@ def terminate_sys_pgcon(self) -> None: self.__sys_pgcon = None del self._sys_pgcon_waiter - async def _pg_connect(self, dbname: str) -> pgcon.PGConnection: + async def _pg_connect( + self, + dbname: str, + source_description: str="pool connection" + ) -> pgcon.PGConnection: ha_serial = self._ha_master_serial if self.get_backend_runtime_params().has_create_database: pg_dbname = self.get_pg_dbname(dbname) @@ -545,8 +551,10 @@ async def _pg_connect(self, dbname: str) -> pgcon.PGConnection: pg_dbname = self.get_pg_dbname(defines.EDGEDB_SUPERUSER_DB) started_at = time.monotonic() try: - rv = await pgcon.connect( - self.get_pgaddr(), pg_dbname, self.get_backend_runtime_params() + rv = await self._cluster.connect( + source_description=source_description, + database=pg_dbname, + 
apply_init_script=True ) if self._server.stmt_cache_size is not None: rv.set_stmt_cache_size(self._server.stmt_cache_size) @@ -592,7 +600,10 @@ async def direct_pgcon( ) -> AsyncGenerator[pgcon.PGConnection, None]: conn = None try: - conn = await self._pg_connect(dbname) + conn = await self._pg_connect( + dbname, + source_description="direct_pgcon" + ) yield conn finally: if conn is not None: @@ -704,7 +715,10 @@ async def _reconnect_sys_pgcon(self) -> None: # 1. This tenant is still running # 2. We still cannot connect to the Postgres cluster try: - conn = await self._pg_connect(defines.EDGEDB_SYSTEM_DB) + conn = await self._pg_connect( + defines.EDGEDB_SYSTEM_DB, + source_description="_reconnect_sys_pgcon" + ) break except OSError: pass @@ -1689,9 +1703,7 @@ def get_debug_info(self) -> dict[str, Any]: instance_config=config.debug_serialize_config( self.get_sys_config()), user_roles=self._roles, - pg_addr={ - k: v for k, v in self.get_pgaddr().items() if k not in ["ssl"] - }, + pg_addr=vars(self._cluster.get_connection_params()), pg_pool=self._pg_pool._build_snapshot(now=time.monotonic()), ) diff --git a/edb/testbase/connection.py b/edb/testbase/connection.py index 906181cdb161..58cd5f0568df 100644 --- a/edb/testbase/connection.py +++ b/edb/testbase/connection.py @@ -329,10 +329,10 @@ class Connection(options._OptionsMixin, abstract.AsyncIOExecutor): def __init__( self, connect_args, *, test_no_tls=False, server_hostname=None - ): + ) -> None: super().__init__() self._connect_args = connect_args - self._protocol = None + self._protocol: typing.Optional[edb_protocol.Protocol] = None self._transport = None self._query_cache = abstract.QueryCache( codecs_registry=protocol.CodecsRegistry(), @@ -341,7 +341,7 @@ def __init__( self._test_no_tls = test_no_tls self._params = None self._server_hostname = server_hostname - self._log_listeners = set() + self._log_listeners: set[typing.Any] = set() def add_log_listener(self, callback): self._log_listeners.add(callback) diff 
--git a/edb/testbase/server.py b/edb/testbase/server.py index 68be4976860a..d8f827bbdf70 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -68,7 +68,7 @@ from edb.server import cluster as edgedb_cluster from edb.server import pgcluster from edb.server import defines as edgedb_defines -from edb.server import pgconnparams +from edb.server.pgconnparams import ConnectionParams from edb.common import assert_data_shape from edb.common import devmode @@ -1196,7 +1196,7 @@ def get_backend_sql_dsn(cls, dbname=None): password = None spec_dsn = os.environ.get('EDGEDB_TEST_BACKEND_DSN') if spec_dsn: - _, params = pgconnparams.parse_dsn(spec_dsn) + params = ConnectionParams(dsn=spec_dsn) password = params.password if dbname is None: @@ -2340,16 +2340,11 @@ async def __aenter__(self): ]) elif self.adjacent_to is not None: settings = self.adjacent_to.get_settings() - pgaddr = settings.get('pgaddr') - if pgaddr is None: - raise RuntimeError('test requires devmode') - pgaddr = json.loads(pgaddr) - pgdsn = ( - f'postgres:///?user={pgaddr["user"]}&port={pgaddr["port"]}' - f'&host={pgaddr["host"]}' - ) + pgdsn = settings.get('pgdsn') + if pgdsn is None: + raise RuntimeError('test requires devmode to access pgdsn') cmd += [ - '--backend-dsn', pgdsn + '--backend-dsn', pgdsn.decode('utf-8') ] elif self.multitenant_config: cmd += ['--multitenant-config-file', self.multitenant_config] diff --git a/edb/tools/wipe.py b/edb/tools/wipe.py index 108a7db28de9..aa62d64f9ae7 100644 --- a/edb/tools/wipe.py +++ b/edb/tools/wipe.py @@ -35,7 +35,6 @@ from edb.server import compiler as edbcompiler from edb.server import defines as edbdef from edb.server import pgcluster -from edb.server import pgconnparams from edb.pgsql import common as pgcommon from edb.pgsql.common import quote_ident as qi @@ -118,11 +117,9 @@ async def do_wipe( data_dir, tenant_id='', ) - cluster.set_connection_params( - pgconnparams.ConnectionParameters( - user='postgres', - database='template1', - ), + 
cluster.update_connection_params( + user='postgres', + database='template1', ) else: raise click.UsageError( @@ -197,7 +194,10 @@ async def wipe_tenant( ) try: - tpl_conn = await cluster.connect(database=tpl_db) + tpl_conn = await cluster.connect( + database=tpl_db, + source_description="wipe_tenant", + ) except pgcon.BackendCatalogNameError: click.secho( f'Instance tenant {tenant!r} does not have the ' diff --git a/pyproject.toml b/pyproject.toml index 4fc3cb7a9934..a417078c0373 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,6 +145,7 @@ show_column_numbers = true show_error_codes = true # This being an error seems super confused to me. disable_error_code = "type-abstract" +enable_incomplete_feature = ["Unpack"] [[tool.mypy.overrides]] module = [ diff --git a/tests/test_backend_connect.py b/tests/test_backend_connect.py index 773744c24d4a..ece5cca6b8c0 100644 --- a/tests/test_backend_connect.py +++ b/tests/test_backend_connect.py @@ -16,11 +16,10 @@ # limitations under the License. 
# import warnings -from typing import Optional +from typing import Optional, Unpack import asyncio import contextlib -import dataclasses import ipaddress import os import pathlib @@ -32,7 +31,6 @@ import stat import sys import tempfile -import textwrap import unittest import unittest.mock import urllib.parse @@ -140,11 +138,8 @@ def setUpClass(cls): async def init_temp_cluster(cls): cluster = cls.cluster = TempCluster() await cluster.lookup_postgres() - cluster.set_connection_params( - pgconnparams.ConnectionParameters( - user='postgres', - database='postgres', - ), + cluster.update_connection_params( + user='postgres', ) await cluster.init(**_get_initdb_options({})) await cluster.trust_local_connections() @@ -177,37 +172,15 @@ def tearDownClass(cls): super().tearDownClass() @classmethod - def get_connection_spec(cls, kwargs=None): - if not kwargs: - kwargs = {} - conn_spec = cls.cluster.get_connection_spec() - conn_spec['host'] = 'localhost' - if kwargs.get('dsn'): - _addrs, params = pgconnparams.parse_dsn(kwargs['dsn']) - for k in ( - 'user', - 'password', - 'database', - 'ssl', - 'sslmode', - 'server_settings', - ): - v = getattr(params, k) - if v is not None: - conn_spec[k] = v - conn_spec.update(kwargs) - if not os.environ.get('PGHOST') and not kwargs.get('dsn'): - if 'database' not in conn_spec: - conn_spec['database'] = 'postgres' - if 'user' not in conn_spec: - conn_spec['user'] = 'postgres' - return conn_spec - - @classmethod - def connect(cls, **kwargs): - conn_spec = cls.get_connection_spec(kwargs) - return pgcon.connect( - conn_spec, cls.dbname, cls.cluster.get_runtime_params() + async def connect(cls, **kwargs: Unpack[pgconnparams.CreateParamsKwargs]): + import inspect + assert cls.cluster is not None + source_description = ("ClusterTestCase: " + f"{inspect.currentframe().f_back.f_code.co_name}") # type: ignore + kwargs['database'] = cls.dbname + return await cls.cluster.connect( + source_description=source_description, + **kwargs ) def setUp(self): @@ 
-217,7 +190,8 @@ def setUp(self): def tearDown(self): try: - self.con.terminate() + if self.con: + self.con.terminate() self.con = None finally: super().tearDown() @@ -333,10 +307,9 @@ async def test_auth_reject(self): await self.connect(user='reject_user') async def test_auth_password_cleartext(self): - with self.assertRaisesRegex(RuntimeError, 'unsupported auth method'): - await self.connect( - user='password_user', - password='correctpassword') + await self.connect( + user='password_user', + password='correctpassword') async def test_auth_password_md5(self): conn = await self.connect( @@ -403,384 +376,6 @@ async def test_auth_unsupported(self): class TestConnectParams(tb.TestCase): - TESTS = [ - { - 'name': 'all_env_default_ssl', - 'env': { - 'PGUSER': 'user', - 'PGDATABASE': 'testdb', - 'PGPASSWORD': 'passw', - 'PGHOST': 'host', - 'PGPORT': '123', - 'PGCONNECT_TIMEOUT': '8', - }, - 'result': ([('host', 123)], { - 'user': 'user', - 'password': 'passw', - 'database': 'testdb', - 'ssl': True, - 'sslmode': SSLMode.prefer, - 'connect_timeout': 8, - }) - }, - - { - 'name': 'dsn_override_env', - 'env': { - 'PGUSER': 'user', - 'PGDATABASE': 'testdb', - 'PGPASSWORD': 'passw', - 'PGHOST': 'host', - 'PGPORT': '123', - 'PGCONNECT_TIMEOUT': '8', - }, - - 'dsn': 'postgres://user2:passw2@host2:456/db2?connect_timeout=6', - - 'result': ([('host2', 456)], { - 'user': 'user2', - 'password': 'passw2', - 'database': 'db2', - 'connect_timeout': 6, - }) - }, - - { - 'name': 'dsn_override_env_ssl', - 'env': { - 'PGUSER': 'user', - 'PGDATABASE': 'testdb', - 'PGPASSWORD': 'passw', - 'PGHOST': 'host', - 'PGPORT': '123', - 'PGSSLMODE': 'allow' - }, - - 'dsn': 'postgres://user2:passw2@host2:456/db2?sslmode=disable', - - 'result': ([('host2', 456)], { - 'user': 'user2', - 'password': 'passw2', - 'database': 'db2', - 'sslmode': SSLMode.disable, - 'ssl': None}) - }, - - { - 'name': 'dsn_overrides_env_partially', - 'env': { - 'PGUSER': 'user', - 'PGDATABASE': 'testdb', - 'PGPASSWORD': 
'passw', - 'PGHOST': 'host', - 'PGPORT': '123', - 'PGSSLMODE': 'allow' - }, - - 'dsn': 'postgres://user3:123123@localhost:5555/abcdef', - - 'result': ([('localhost', 5555)], { - 'user': 'user3', - 'password': '123123', - 'database': 'abcdef', - 'ssl': True, - 'sslmode': SSLMode.allow}) - }, - - { - 'name': 'dsn_override_env_ssl_prefer', - 'env': { - 'PGUSER': 'user', - 'PGDATABASE': 'testdb', - 'PGPASSWORD': 'passw', - 'PGHOST': 'host', - 'PGPORT': '123', - 'PGSSLMODE': 'prefer' - }, - - 'dsn': 'postgres://user2:passw2@host2:456/db2?sslmode=disable', - - 'result': ([('host2', 456)], { - 'user': 'user2', - 'password': 'passw2', - 'database': 'db2', - 'sslmode': SSLMode.disable, - 'ssl': None}) - }, - - { - 'name': 'dsn_overrides_env_partially_ssl_prefer', - 'env': { - 'PGUSER': 'user', - 'PGDATABASE': 'testdb', - 'PGPASSWORD': 'passw', - 'PGHOST': 'host', - 'PGPORT': '123', - 'PGSSLMODE': 'prefer' - }, - - 'dsn': 'postgres://user3:123123@localhost:5555/abcdef', - - 'result': ([('localhost', 5555)], { - 'user': 'user3', - 'password': '123123', - 'database': 'abcdef', - 'ssl': True, - 'sslmode': SSLMode.prefer}) - }, - - { - 'name': 'dsn_only', - 'dsn': 'postgres://user3:123123@localhost:5555/abcdef', - 'result': ([('localhost', 5555)], { - 'user': 'user3', - 'password': '123123', - 'database': 'abcdef'}) - }, - - { - 'name': 'dsn_only_multi_host', - 'dsn': 'postgresql://user@host1,host2/db', - 'result': ([('host1', 5432), ('host2', 5432)], { - 'database': 'db', - 'user': 'user', - }) - }, - - { - 'name': 'dsn_only_multi_host_and_port', - 'dsn': 'postgresql://user@host1:1111,host2:2222/db', - 'result': ([('host1', 1111), ('host2', 2222)], { - 'database': 'db', - 'user': 'user', - }) - }, - - { - 'name': 'dsn_combines_env_multi_host', - 'env': { - 'PGHOST': 'host1:1111,host2:2222', - 'PGUSER': 'foo', - }, - 'dsn': 'postgresql:///db', - 'result': ([('host1', 1111), ('host2', 2222)], { - 'database': 'db', - 'user': 'foo', - }) - }, - - { - 'name': 
'dsn_multi_host_combines_env', - 'env': { - 'PGUSER': 'foo', - }, - 'dsn': 'postgresql:///db?host=host1:1111,host2:2222', - 'result': ([('host1', 1111), ('host2', 2222)], { - 'database': 'db', - 'user': 'foo', - }) - }, - - { - 'name': 'params_multi_host_dsn_env_mix', - 'env': { - 'PGUSER': 'foo', - }, - 'dsn': 'postgresql://host1,host2/db', - 'result': ([('host1', 5432), ('host2', 5432)], { - 'database': 'db', - 'user': 'foo', - }) - }, - - { - 'name': 'dsn_settings_override_and_ssl', - 'dsn': 'postgresql://me:ask@127.0.0.1:888/' - 'db?param=sss¶m=123&host=testhost&user=testuser' - '&port=2222&database=testdb&sslmode=require', - 'result': ([('127.0.0.1', 888)], { - 'server_settings': {'param': '123'}, - 'user': 'me', - 'password': 'ask', - 'database': 'db', - 'ssl': True, - 'sslmode': SSLMode.require}) - }, - - { - 'name': 'multiple_settings', - 'dsn': 'postgresql://me:ask@127.0.0.1:888/' - 'db?param=sss¶m=123&host=testhost&user=testuser' - '&port=2222&database=testdb&sslmode=verify_full' - '&aa=bb', - 'result': ([('127.0.0.1', 888)], { - 'server_settings': {'aa': 'bb', 'param': '123'}, - 'user': 'me', - 'password': 'ask', - 'database': 'db', - 'sslmode': SSLMode.verify_full, - 'ssl': True}) - }, - - { - 'name': 'dsn_only_unix', - 'dsn': 'postgresql:///dbname?host=/unix_sock/test&user=spam', - 'result': ([('/unix_sock/test', 5432)], { - 'user': 'spam', - 'database': 'dbname'}) - }, - - { - 'name': 'dsn_only_quoted', - 'dsn': 'postgresql://us%40r:p%40ss@h%40st1,h%40st2:543%33/d%62', - 'result': ( - [('h@st1', 5432), ('h@st2', 5433)], - { - 'user': 'us@r', - 'password': 'p@ss', - 'database': 'db', - } - ) - }, - - { - 'name': 'dsn_only_unquoted_host', - 'dsn': 'postgresql://user:p@ss@host/db', - 'result': ( - [('ss@host', 5432)], - { - 'user': 'user', - 'password': 'p', - 'database': 'db', - } - ) - }, - - { - 'name': 'dsn_only_quoted_params', - 'dsn': 'postgresql:///d%62?user=us%40r&host=h%40st&port=543%33', - 'result': ( - [('h@st', 5433)], - { - 'user': 'us@r', - 
'database': 'db', - } - ) - }, - - { - 'name': 'dsn_ipv6_multi_host', - 'dsn': 'postgresql://user@[2001:db8::1234%25eth0],[::1]/db', - 'result': ([('2001:db8::1234%eth0', 5432), ('::1', 5432)], { - 'database': 'db', - 'user': 'user', - }) - }, - - { - 'name': 'dsn_ipv6_multi_host_port', - 'dsn': 'postgresql://user@[2001:db8::1234]:1111,[::1]:2222/db', - 'result': ([('2001:db8::1234', 1111), ('::1', 2222)], { - 'database': 'db', - 'user': 'user', - }) - }, - - { - 'name': 'dsn_ipv6_multi_host_query_part', - 'dsn': 'postgresql:///db?user=user&host=[2001:db8::1234],[::1]', - 'result': ([('2001:db8::1234', 5432), ('::1', 5432)], { - 'database': 'db', - 'user': 'user', - }) - }, - - { - 'name': 'dsn_only_illegal_protocol', - 'dsn': 'pq:///dbname?host=/unix_sock/test&user=spam', - 'error': (ValueError, 'Invalid DSN.*') - }, - { - 'name': 'env_ports_mismatch_dsn_multi_hosts', - 'dsn': 'postgresql://host1,host2,host3/db', - 'env': {'PGPORT': '111,222'}, - 'error': ( - ValueError, - 'Unexpected number of ports.*' - ) - }, - { - 'name': 'dsn_only_quoted_unix_host_port_in_params', - 'dsn': 'postgres://user@?port=56226&host=%2Ftmp', - 'result': ( - [('/tmp', 56226)], - { - 'user': 'user', - 'database': 'user', - 'sslmode': SSLMode.disable, - 'ssl': None - } - ) - }, - { - 'name': 'dsn_only_cloudsql', - 'dsn': 'postgres:///db?host=/cloudsql/' - 'project:region:instance-name&user=spam', - 'result': ( - [( - '/cloudsql/project:region:instance-name', - 5432, - )], { - 'user': 'spam', - 'database': 'db' - } - ) - }, - { - 'name': 'dsn_only_cloudsql_unix_and_tcp', - 'dsn': 'postgres:///db?host=127.0.0.1:5432,/cloudsql/' - 'project:region:instance-name,localhost:5433&user=spam', - 'result': ( - [ - ('127.0.0.1', 5432), - ( - '/cloudsql/project:region:instance-name', - 5432, - ), - ('localhost', 5433) - ], { - 'user': 'spam', - 'database': 'db', - 'ssl': True, - 'sslmode': SSLMode.prefer, - } - ) - }, - *[ - { - 'name': f'connect_timeout_{given}', - 'dsn': 
f'postgres://spam@127.0.0.1:5432/postgres?' - f'connect_timeout={given}', - 'result': ( - [('127.0.0.1', 5432)], - { - 'user': 'spam', - 'database': 'postgres', - 'connect_timeout': expected, - } - ) - } - for given, expected in [ - ('-8', None), - ('-1', None), - ('0', None), - ('1', 2), - ('2', 2), - ('3', 3), - ] - ], - ] - @contextlib.contextmanager def environ(self, **kwargs): old_vals = {} @@ -831,27 +426,16 @@ def run_testcase(self, testcase): if expected_error: es.enter_context(self.assertRaisesRegex(*expected_error)) - addrs, conn_params = pgconnparams.parse_dsn(dsn=dsn) - - params = {} - for k in dataclasses.fields(conn_params): - k = k.name - v = getattr(conn_params, k) - if v or (expected is not None and k in expected[1]): - params[k] = v + conn_params = pgconnparams.ConnectionParams(dsn=dsn) + conn_params = conn_params.resolve() - if isinstance(params.get('ssl'), ssl.SSLContext): - params['ssl'] = True - - result = (list(addrs), params) + to_dict = conn_params.__dict__ + host = to_dict.pop('host', None).split(',') + port = map(int, to_dict.pop('port', None).split(',')) + to_dict.pop('sslmode', None) + result = (list(zip(host, port)), to_dict) if expected is not None: - if 'ssl' not in expected[1]: - # Avoid the hassle of specifying the default SSL mode - # unless explicitly tested for. 
- params.pop('ssl', None) - params.pop('sslmode', None) - self.assertEqual( expected, result, @@ -898,149 +482,6 @@ def test_test_connect_params_run_testcase(self): ) }) - def test_connect_params(self): - with mock_dot_postgresql(): - for testcase in self.TESTS: - self.run_testcase(testcase) - - def test_connect_pgpass_regular(self): - passfile = tempfile.NamedTemporaryFile('w+t', delete=False) - passfile.write(textwrap.dedent(R''' - abc:*:*:user:password from pgpass for user@abc - localhost:*:*:*:password from pgpass for localhost - cde:5433:*:*:password from pgpass for cde:5433 - - *:*:*:testuser:password from pgpass for testuser - *:*:testdb:*:password from pgpass for testdb - # comment - *:*:test\:db:test\\:password from pgpass with escapes - ''')) - passfile.close() - os.chmod(passfile.name, stat.S_IWUSR | stat.S_IRUSR) - - try: - # passfile path in env - self.run_testcase({ - 'env': { - 'PGPASSFILE': passfile.name - }, - 'dsn': 'postgres://user@abc/db', - 'result': ( - [('abc', 5432)], - { - 'password': 'password from pgpass for user@abc', - 'user': 'user', - 'database': 'db', - } - ) - }) - - # passfile path in dsn - self.run_testcase({ - 'dsn': 'postgres://user@abc/db?passfile={}'.format( - passfile.name), - 'result': ( - [('abc', 5432)], - { - 'password': 'password from pgpass for user@abc', - 'user': 'user', - 'database': 'db', - } - ) - }) - - self.run_testcase({ - 'dsn': 'postgres://user@localhost/db?passfile={}'.format( - passfile.name - ), - 'result': ( - [('localhost', 5432)], - { - 'password': 'password from pgpass for localhost', - 'user': 'user', - 'database': 'db', - } - ) - }) - - # unix socket gets normalized as localhost - self.run_testcase({ - 'dsn': 'postgres:///db?user=user&host=/tmp&passfile={}'.format( - passfile.name - ), - 'result': ( - [('/tmp', 5432)], - { - 'password': 'password from pgpass for localhost', - 'user': 'user', - 'database': 'db', - } - ) - }) - - # port matching (also tests that `:` can be part of password) - 
self.run_testcase({ - 'dsn': 'postgres://user@cde:5433/db?passfile={}'.format( - passfile.name - ), - 'result': ( - [('cde', 5433)], - { - 'password': 'password from pgpass for cde:5433', - 'user': 'user', - 'database': 'db', - } - ) - }) - - # user matching - self.run_testcase({ - 'dsn': 'postgres://testuser@def/db?passfile={}'.format( - passfile.name - ), - 'result': ( - [('def', 5432)], - { - 'password': 'password from pgpass for testuser', - 'user': 'testuser', - 'database': 'db', - } - ) - }) - - # database matching - self.run_testcase({ - 'dsn': 'postgres://user@efg/testdb?passfile={}'.format( - passfile.name - ), - 'result': ( - [('efg', 5432)], - { - 'password': 'password from pgpass for testdb', - 'user': 'user', - 'database': 'testdb', - } - ) - }) - - # test escaping - self.run_testcase({ - 'dsn': 'postgres://{}@fgh/{}?passfile={}'.format( - 'test\\', 'test:db', passfile.name - ), - 'result': ( - [('fgh', 5432)], - { - 'password': 'password from pgpass with escapes', - 'user': 'test\\', - 'database': 'test:db', - } - ) - }) - - finally: - os.unlink(passfile.name) - def test_connect_pgpass_badness_mode(self): # Verify that .pgpass permissions are checked with tempfile.NamedTemporaryFile('w+t') as passfile: @@ -1153,13 +594,11 @@ async def test_connection_connect_timeout(self): # The backlog on macOS is different from Linux server.listen(0) host, port = server.getsockname() - conn_spec = { - 'host': host, - 'port': port, - 'user': 'foo', - 'server_settings': {}, - 'connect_timeout': 2, - } + conn_spec = pgconnparams.ConnectionParams( + hosts=[(host, port)], + user='foo', + connect_timeout=2, + ) async def placeholder(): async with asyncio.timeout(2): @@ -1184,8 +623,8 @@ async def placeholder(): async with asyncio.timeout(4): # failsafe await pgcon.connect( conn_spec, - 'foo', - pg_params.get_default_runtime_params(), + source_description="test_connection_connect_timeout", + backend_params=pg_params.get_default_runtime_params(), ) finally: @@ -1234,21 
+673,17 @@ def check(): await self.con.restore(None, b'', {}) async def test_connection_ssl_to_no_ssl_server(self): - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ssl_context.load_verify_locations(SSL_CA_CERT_FILE) - with self.assertRaisesRegex(ConnectionError, 'rejected SSL'): await self.connect( host='localhost', - sslmode=SSLMode.require, - ssl=ssl_context) + sslmode=SSLMode.require) async def test_connection_sslmode_no_ssl_server(self): async def verify_works(sslmode): con = None try: con = await self.connect( - dsn='postgresql://foo/?sslmode=' + sslmode, + sslmode=SSLMode.parse(sslmode), user='postgres', database='postgres', host='localhost') @@ -1263,7 +698,7 @@ async def verify_fails(sslmode): try: with self.assertRaises(ConnectionError): con = await self.connect( - dsn='postgresql://foo/?sslmode=' + sslmode, + sslmode=SSLMode.parse(sslmode), user='postgres', database='postgres', host='localhost') @@ -1334,24 +769,16 @@ def _add_hba_entry(self): class TestSSLConnection(BaseTestSSLConnection): def _add_hba_entry(self): self.cluster.add_hba_entry( - type='hostssl', address=ipaddress.ip_network('127.0.0.0/24'), - database=self.dbname, user='ssl_user', - auth_method='trust') - - self.cluster.add_hba_entry( - type='hostssl', address=ipaddress.ip_network('::1/128'), + type='hostssl', address="all", database=self.dbname, user='ssl_user', auth_method='trust') async def test_ssl_connection_custom_context(self): - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ssl_context.load_verify_locations(SSL_CA_CERT_FILE) - con = await self.connect( host='localhost', user='ssl_user', sslmode=SSLMode.require, - ssl=ssl_context) + sslrootcert=SSL_CA_CERT_FILE) try: await self.assertConnected(con) @@ -1364,7 +791,7 @@ async def verify_works(sslmode, *, host='localhost'): con = None try: con = await self.connect( - dsn='postgresql://foo/postgres?sslmode=' + sslmode, + sslmode=SSLMode.parse(sslmode), host=host, user='ssl_user') await self.assertConnected(con) @@ 
-1381,7 +808,7 @@ async def verify_fails(sslmode, *, host='localhost', exn_type): self.loop.set_exception_handler(lambda *args: None) with self.assertRaises(exn_type, msg=f"{sslmode} {host}"): con = await self.connect( - dsn='postgresql://foo/?sslmode=' + sslmode, + sslmode=SSLMode.parse(sslmode), host=host, user='ssl_user') await self.assertConnected(con) @@ -1429,7 +856,8 @@ async def test_ssl_connection_default_context(self): host='localhost', user='ssl_user', sslmode=SSLMode.verify_full, - ssl=ssl.create_default_context() + # This won't validate + sslrootcert=CLIENT_CA_CERT_FILE ) finally: self.loop.set_exception_handler(old_handler) @@ -1487,11 +915,6 @@ def _add_hba_entry(self): auth_method='cert') async def test_ssl_connection_client_auth_fails_with_wrong_setup(self): - ssl_context = ssl.create_default_context( - ssl.Purpose.SERVER_AUTH, - cafile=SSL_CA_CERT_FILE, - ) - with self.assertRaisesRegex( errors.BackendError, "requires a valid client certificate", @@ -1500,7 +923,7 @@ async def test_ssl_connection_client_auth_fails_with_wrong_setup(self): host='localhost', user='ssl_user', sslmode=SSLMode.require, - ssl=ssl_context, + sslrootcert=SSL_CA_CERT_FILE, ) async def _test_works(self, **conn_args): @@ -1513,20 +936,14 @@ async def _test_works(self, **conn_args): async def test_ssl_connection_client_auth_custom_context(self): for key_file in (CLIENT_SSL_KEY_FILE, CLIENT_SSL_PROTECTED_KEY_FILE): - ssl_context = ssl.create_default_context( - ssl.Purpose.SERVER_AUTH, - cafile=SSL_CA_CERT_FILE, - ) - ssl_context.load_cert_chain( - CLIENT_SSL_CERT_FILE, - keyfile=key_file, - password='secRet', - ) await self._test_works( host='localhost', user='ssl_user', sslmode=SSLMode.require, - ssl=ssl_context, + sslcert=CLIENT_SSL_CERT_FILE, + sslrootcert=SSL_CA_CERT_FILE, + sslpassword='secRet', + sslkey=key_file ) async def test_ssl_connection_client_auth_dsn(self): @@ -1585,7 +1002,7 @@ async def verify_works(sslmode, *, host='localhost'): con = None try: con = await 
self.connect( - dsn='postgresql://foo/postgres?sslmode=' + sslmode, + sslmode=SSLMode.parse(sslmode), host=host, user='ssl_user') await self.assertConnected(con) @@ -1604,7 +1021,7 @@ async def verify_fails(sslmode, *, host='localhost'): errors.BackendError ) as cm: con = await self.connect( - dsn='postgresql://foo/?sslmode=' + sslmode, + sslmode=SSLMode.parse(sslmode), host=host, user='ssl_user') await self.assertConnected(con) diff --git a/tests/test_server_ops.py b/tests/test_server_ops.py index 80b7cb30ef6b..91003e7a5680 100644 --- a/tests/test_server_ops.py +++ b/tests/test_server_ops.py @@ -43,7 +43,7 @@ from edb import protocol from edb.common import devmode from edb.protocol import protocol as edb_protocol # type: ignore -from edb.server import args, pgcluster, pgconnparams +from edb.server import args, pgcluster from edb.server import cluster as edbcluster from edb.testbase import server as tb @@ -475,11 +475,9 @@ async def test(pgdata_path): with tempfile.TemporaryDirectory() as td: cluster = await pgcluster.get_local_pg_cluster( td, max_connections=actual, log_level='s') - cluster.set_connection_params( - pgconnparams.ConnectionParameters( - user='postgres', - database='template1', - ), + cluster.update_connection_params( + user='postgres', + database='template1', ) self.assertTrue(await cluster.ensure_initialized()) await cluster.start() @@ -509,11 +507,9 @@ async def test(pgdata_path, tenant): with tempfile.TemporaryDirectory() as td: cluster = await pgcluster.get_local_pg_cluster(td, log_level='s') - cluster.set_connection_params( - pgconnparams.ConnectionParameters( - user='postgres', - database='template1', - ), + cluster.update_connection_params( + user='postgres', + database='template1', ) self.assertTrue(await cluster.ensure_initialized()) @@ -546,11 +542,9 @@ async def test(pgdata_path, tenant): with tempfile.TemporaryDirectory() as td: cluster = await pgcluster.get_local_pg_cluster(td, log_level='s') - cluster.set_connection_params( - 
pgconnparams.ConnectionParameters( - user='postgres', - database='template1', - ), + cluster.update_connection_params( + user='postgres', + database='template1', ) self.assertTrue(await cluster.ensure_initialized()) @@ -598,11 +592,9 @@ async def test(pgdata_path): with tempfile.TemporaryDirectory() as td: cluster = await pgcluster.get_local_pg_cluster(td, log_level='s') - cluster.set_connection_params( - pgconnparams.ConnectionParameters( - user='postgres', - database='template1', - ), + cluster.update_connection_params( + user='postgres', + database='template1', ) self.assertTrue(await cluster.ensure_initialized()) await cluster.start() @@ -633,11 +625,9 @@ async def _test_server_ops_ignore_other_tenants(self, td, user): async def test_server_ops_ignore_other_tenants(self): with tempfile.TemporaryDirectory() as td: cluster = await pgcluster.get_local_pg_cluster(td, log_level='s') - cluster.set_connection_params( - pgconnparams.ConnectionParameters( - user='postgres', - database='template1', - ), + cluster.update_connection_params( + user='postgres', + database='template1', ) self.assertTrue(await cluster.ensure_initialized()) @@ -652,11 +642,9 @@ async def test_server_ops_ignore_other_tenants(self): async def test_server_ops_ignore_other_tenants_single_role(self): with tempfile.TemporaryDirectory() as td: cluster = await pgcluster.get_local_pg_cluster(td, log_level='s') - cluster.set_connection_params( - pgconnparams.ConnectionParameters( - user='postgres', - database='template1', - ), + cluster.update_connection_params( + user='postgres', + database='template1', ) self.assertTrue(await cluster.ensure_initialized()) cluster.add_hba_entry( @@ -666,7 +654,7 @@ async def test_server_ops_ignore_other_tenants_single_role(self): auth_method="trust", ) await cluster.start() - conn = await cluster.connect() + conn = await cluster.connect(source_description="test_server_ops") setup = b"""\ CREATE ROLE single WITH LOGIN CREATEDB; CREATE DATABASE single; @@ -1218,11 +1206,9 
@@ async def test(pgdata_path): cluster = await pgcluster.get_local_pg_cluster( td, max_connections=20, log_level='s' ) - cluster.set_connection_params( - pgconnparams.ConnectionParameters( - user='postgres', - database='template1', - ), + cluster.update_connection_params( + user='postgres', + database='template1', ) self.assertTrue(await cluster.ensure_initialized()) await cluster.start() @@ -1276,11 +1262,9 @@ async def _test_server_ops_restore_with_schema_signal(self, sd1, sd2): async def _init_pg_cluster(self, path): cluster = await pgcluster.get_local_pg_cluster(path, log_level='s') - cluster.set_connection_params( - pgconnparams.ConnectionParameters( - user='postgres', - database='template1', - ), + cluster.update_connection_params( + user='postgres', + database='template1', ) self.assertTrue(await cluster.ensure_initialized()) await cluster.start()