diff --git a/Cargo.lock b/Cargo.lock index 7a3d36cb912..081ff2ee5cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -242,6 +242,20 @@ name = "bytemuck" version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773d90827bc3feecfb67fab12e24de0749aad83c74b9504ecde46237b5cd24e2" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "byteorder" @@ -255,6 +269,16 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" +[[package]] +name = "captive_postgres" +version = "0.1.0" +dependencies = [ + "gel_auth", + "openssl", + "socket2", + "tempfile", +] + [[package]] name = "cardinality-estimator" version = "1.0.2" @@ -493,6 +517,17 @@ dependencies = [ "typenum", ] +[[package]] +name = "db_proto" +version = "0.1.0" +dependencies = [ + "derive_more", + "paste", + "pretty_assertions", + "thiserror 1.0.63", + "uuid", +] + [[package]] name = "derive_more" version = "1.0.0" @@ -1560,11 +1595,15 @@ name = "pgrust" version = "0.1.0" dependencies = [ "base64", + "bytemuck", + "captive_postgres", "clap", "clap_derive", + "db_proto", "derive_more", "futures", "gel_auth", + "hex-literal", "hexdump", "libc", "openssl", @@ -1577,8 +1616,6 @@ dependencies = [ "scopeguard", "serde", "serde_derive", - "socket2", - "tempfile", "test-log", "thiserror 1.0.63", "tokio", @@ -2292,9 +2329,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +checksum = 
"c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" dependencies = [ "libc", "windows-sys 0.52.0", diff --git a/Cargo.toml b/Cargo.toml index a7ec25a43fd..50a653d221e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,9 @@ members = [ "edb/graphql-rewrite", "edb/server/_rust_native", "rust/auth", + "rust/captive_postgres", "rust/conn_pool", + "rust/db_proto", "rust/pgrust", "rust/http", "rust/pyo3_util" @@ -20,6 +22,8 @@ tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter"] } gel_auth = { path = "rust/auth" } +db_proto = { path = "rust/db_proto" } +captive_postgres = { path = "rust/captive_postgres" } conn_pool = { path = "rust/conn_pool" } pgrust = { path = "rust/pgrust" } http = { path = "rust/http" } diff --git a/docs/reference/sql_adapter.rst b/docs/reference/sql_adapter.rst index a10960a466e..2cecdf8fa02 100644 --- a/docs/reference/sql_adapter.rst +++ b/docs/reference/sql_adapter.rst @@ -300,54 +300,207 @@ construct is mapped to PostgreSQL schema: - Aliases are not mapped to PostgreSQL schema. -- Globals are mapped to connection settings, prefixed with ``global``. - For example, a ``global default::username: str`` can be set using: - - .. code-block:: sql +.. versionadded:: 6.0 - SET "global default::username" TO 'Tom'``. + - Globals are mapped to connection settings, prefixed with ``global``. + For example, a ``global default::username: str`` can be accessed using: -- Access policies are applied to object type tables when setting - ``apply_access_policies_pg`` is set to ``true``. + .. code-block:: sql -- Mutation rewrites and triggers are applied to all DML commands. + SET "global default::username" TO 'Tom'``; + SHOW "global default::username"; + - Access policies are applied to object type tables when setting + ``apply_access_policies_pg`` is set to ``true``. + + - Mutation rewrites and triggers are applied to all DML commands. 
DML commands ============ -When using ``INSERT``, ``DELETE`` or ``UPDATE`` on any table, mutation rewrites -and triggers are applied. These commands do not have a straight-forward -translation to EdgeQL DML commands, but instead use the following mapping: +.. versionchanged:: _default + + Data Modification Language commands (``INSERT``, ``UPDATE``, ``DELETE``, ..) + are not supported in EdgeDB <6.0. + +.. versionchanged:: 6.0 + +.. versionadded:: 6.0 -- ``INSERT INTO "Foo"`` object table maps to ``insert Foo``, + When using ``INSERT``, ``DELETE`` or ``UPDATE`` on any table, mutation + rewrites and triggers are applied. These commands do not have a + straight-forward translation to EdgeQL DML commands, but instead use the + following mapping: -- ``INSERT INTO "Foo.keywords"`` link/property table maps to an - ``update Foo { keywords += ... }``, + - ``INSERT INTO "Foo"`` object table maps to ``insert Foo``, -- ``DELETE FROM "Foo"`` object table maps to ``delete Foo``, + - ``INSERT INTO "Foo.keywords"`` link/property table maps to an + ``update Foo { keywords += ... }``, -- ``DELETE FROM "Foo.keywords"`` link property/table maps to - ``update Foo { keywords -= ... }``, + - ``DELETE FROM "Foo"`` object table maps to ``delete Foo``, -- ``UPDATE "Foo"`` object table maps to ``update Foo set { ... }``, + - ``DELETE FROM "Foo.keywords"`` link property/table maps to + ``update Foo { keywords -= ... }``, -- ``UPDATE "Foo.keywords"`` is not supported. + - ``UPDATE "Foo"`` object table maps to ``update Foo set { ... }``, + + - ``UPDATE "Foo.keywords"`` is not supported. Connection settings =================== -SQL adapter supports a limited subset of PostgreSQL connection settings. -There are the following additionally connection settings: +SQL adapter supports most of PostgreSQL connection settings +(for example ``search_path``), in the same manner as plain PostgreSQL: + +.. code-block:: sql + + SET search_path TO my_module; + + SHOW search_path; + + RESET search_path; + +.. 
versionadded:: 6.0 + + In addition, there are the following EdgeDB-specific settings: + + - settings prefixed with ``"global "`` set the values of globals. + + Because SQL syntax allows only string, integer and float constants in + ``SET`` command, globals of other types such as ``datetime`` cannot be set + this way. + + .. code-block:: sql + + SET "global my_module::hello" TO 'world'; + + Special handling is in place to enable setting: + - ``bool`` types via integers 0 or 1), + - ``uuid`` types via hex-encoded strings. + + .. code-block:: sql + + SET "global my_module::current_user_id" + TO "592c62c6-73dd-4b7b-87ba-46e6d34ec171"; + SET "global my_module::is_admin" TO 1; + + To set globals of other types via SQL, it is recommended to change the + global to use one of the simple types instead, and use appropriate casts + where the global is used. + + + - ``allow_user_specified_id`` (default ``false``), + + - ``apply_access_policies_pg`` (default ``false``), + + Note that if ``allow_user_specified_id`` or ``apply_access_policies_pg`` are + unset, they default to configuration set by ``configure current database`` + EdgeQL command. + + +Introspection +============= + +The adapter emulates introspection schemas of PostgreSQL: ``information_schema`` +and ``pg_catalog``. + +Both schemas are not perfectly emulated, since they are quite large and +complicated stores of information, that also changed between versions of +PostgreSQL. + +Because of that, some tools might show objects that are not queryable or might +report problems when introspecting. In such cases, please report the problem on +GitHub so we can track the incompatibility down. + +Note that since the two information schemas are emulated, querying them may +perform worse compared to other tables in the database. As a result, tools like +``pg_dump`` and other introspection utilities might seem slower. + + +Locking +======= + +.. versionchanged:: _default + + SQL adapter does not support ``LOCK`` in EdgeDB <6.0. 
+ +.. versionchanged:: 6.0 + +.. versionadded:: 6.0 + + SQL adapter supports LOCK command with the following limitations: + + - it cannot be used on tables that represent object types with access + properties or links of such objects, + - it cannot be used on tables that represent object types that have child + types extending them. + +Query cache +=========== + +An SQL query is issued to EdgeDB, it is compiled to an internal SQL query, which +is then issued to the backing PostgreSQL instance. The compiled query is then +cached, so each following issue of the same query will not perform any +compilation, but just pass through the cached query. + +.. versionadded:: 6.0 + + Additionally, most queries are "normalized" before compilation. This process + extracts constant values and replaces them by internal query parameters. + This allows sharing of compilation cache between queries that differ in + only constant values. This process is totally opaque and is fully handled by + EdgeDB. For example: + + .. code-block:: sql + + SELECT $1, 42; + + ... is normalized to: + + .. code-block:: sql + + SELECT $1, $2; + + This way, when a similar query is issued to EdgeDB: + + .. code-block:: sql + + SELECT $1, 500; + + ... it normalizes to the same query as before, so it can reuse the query + cache. + + Note that normalization process does not (yet) remove any whitespace, so + queries ``SELECT 1;`` and ``SELECT 1 ;`` are compiled separately. + + +Known limitations +================= + +Following SQL statements are not supported: + +- ``CREATE``, ``ALTER``, ``DROP``, + +- ``TRUNCATE``, ``COMMENT``, ``SECURITY LABEL``, ``IMPORT FOREIGN SCHEMA``, + +- ``GRANT``, ``REVOKE``, + +- ``OPEN``, ``FETCH``, ``MOVE``, ``CLOSE``, ``DECLARE``, ``RETURN``, + +- ``CHECKPOINT``, ``DISCARD``, ``CALL``, + +- ``REINDEX``, ``VACUUM``, ``CLUSTER``, ``REFRESH MATERIALIZED VIEW``, + +- ``LISTEN``, ``UNLISTEN``, ``NOTIFY``, + +- ``LOAD``. 
-- ``allow_user_specified_id`` (default ``false``), -- ``apply_access_policies_pg`` (default ``false``), -- settings prefixed with ``"global "`` can use used to set values of globals. +Following functions are not supported: -Note that if ``allow_user_specified_id`` or ``apply_access_policies_pg`` are -unset, they default to configuration set by ``configure current database`` -EdgeQL command. +- ``set_config``, +- ``pg_filenode_relation``, +- most of system administration functions. Example: gradual transition from ORMs to EdgeDB diff --git a/edb/buildmeta.py b/edb/buildmeta.py index 31a13a15ded..dd944f10e9e 100644 --- a/edb/buildmeta.py +++ b/edb/buildmeta.py @@ -60,7 +60,7 @@ # The merge conflict there is a nice reminder that you probably need # to write a patch in edb/pgsql/patches.py, and then you should preserve # the old value. -EDGEDB_CATALOG_VERSION = 2024_12_17_00_00 +EDGEDB_CATALOG_VERSION = 2024_01_02_00_00 EDGEDB_MAJOR_VERSION = 7 diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py index 035b13e71ee..98d4f6a7f71 100644 --- a/edb/pgsql/metaschema.py +++ b/edb/pgsql/metaschema.py @@ -6404,12 +6404,14 @@ def make_wrapper_view(name: str) -> trampoline.VersionedView: name=('edgedbsql', 'uuid_to_oid'), args=( ('id', 'uuid'), + # extra is two extra bits to throw into the oid, for now + ('extra', 'int4', '0'), ), returns=('oid',), volatility='immutable', text=""" SELECT ( - ('x' || substring(id::text, 2, 7))::bit(28)::bigint + ('x' || substring(id::text, 2, 7))::bit(28)::bigint*4 + extra + 40000)::oid; """ ) @@ -7209,7 +7211,9 @@ def make_wrapper_view(name: str) -> trampoline.VersionedView: -- foreign keys for object tables SELECT - edgedbsql_VER.uuid_to_oid(sl.id) as oid, + -- uuid_to_oid needs "extra" arg to disambiguate from the link table + -- keys below + edgedbsql_VER.uuid_to_oid(sl.id, 0) as oid, vt.table_name || '_fk_' || sl.name AS conname, edgedbsql_VER.uuid_to_oid(vt.module_id) AS connamespace, 'f'::"char" AS contype, @@ -7255,7 +7259,9 @@ 
def make_wrapper_view(name: str) -> trampoline.VersionedView: -- - single link with link properties (source & target), -- these constraints do not actually exist, so we emulate it entierly SELECT - edgedbsql_VER.uuid_to_oid(sp.id) AS oid, + -- uuid_to_oid needs "extra" arg to disambiguate from other + -- constraints using this pointer + edgedbsql_VER.uuid_to_oid(sp.id, spec.attnum) AS oid, vt.table_name || '_fk_' || spec.name AS conname, edgedbsql_VER.uuid_to_oid(vt.module_id) AS connamespace, 'f'::"char" AS contype, @@ -7871,6 +7877,10 @@ def construct_pg_view( returns=('text',), volatility='stable', text=r""" + -- Wrap in a subquery SELECT so that we get a clear failure + -- if something is broken and this returns multiple rows. + -- (By default it would silently return the first.) + SELECT ( SELECT CASE WHEN contype = 'p' THEN 'PRIMARY KEY(' || ( @@ -7883,7 +7893,6 @@ def construct_pg_view( SELECT attname FROM edgedbsql_VER.pg_attribute WHERE attrelid = conrelid AND attnum = ANY(conkey) - LIMIT 1 ) || '")' || ' REFERENCES "' || pn.nspname || '"."' || pc.relname || '"(id)' ELSE '' @@ -7893,6 +7902,7 @@ def construct_pg_view( LEFT JOIN edgedbsql_VER.pg_namespace pn ON pc.relnamespace = pn.oid WHERE con.oid = conid + ) """ ), trampoline.VersionedFunction( diff --git a/edb/server/cluster.py b/edb/server/cluster.py index 6b9a55e1a31..6f4524bb06f 100644 --- a/edb/server/cluster.py +++ b/edb/server/cluster.py @@ -324,7 +324,11 @@ async def test() -> None: started = time.monotonic() await test() left -= (time.monotonic() - started) - if res := self._admin_query("SELECT ();", f"{max(1, int(left))}s"): + if res := self._admin_query( + "SELECT ();", + f"{max(1, int(left))}s", + check=False, + ): raise ClusterError( f'could not connect to edgedb-server ' f'within {timeout} seconds (exit code = {res})') from None @@ -333,6 +337,7 @@ def _admin_query( self, query: str, wait_until_available: str = "0s", + check: bool=True, ) -> int: args = [ "edgedb", @@ -350,12 +355,13 @@ def 
_admin_query( wait_until_available, query, ] - res = subprocess.call( + res = subprocess.run( args=args, - stdout=subprocess.DEVNULL, + check=check, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, ) - return res + return res.returncode async def set_test_config(self) -> None: self._admin_query(f''' diff --git a/edb/server/inplace_upgrade.py b/edb/server/inplace_upgrade.py index d07ca92bee6..86b9c4551b2 100644 --- a/edb/server/inplace_upgrade.py +++ b/edb/server/inplace_upgrade.py @@ -57,6 +57,7 @@ from edb.pgsql import common as pg_common from edb.pgsql import dbops +from edb.pgsql import metaschema from edb.pgsql import trampoline @@ -273,6 +274,15 @@ async def _upgrade_one( except Exception: raise + # Refresh the pg_catalog materialized views + current_block = dbops.PLTopBlock() + refresh = metaschema.generate_sql_information_schema_refresh( + backend_params.instance_params.version + ) + refresh.generate(current_block) + patch = current_block.to_string() + await ctx.conn.sql_execute(patch.encode('utf-8')) + new_local_spec = config.load_spec_from_schema( schema, only_exts=True, diff --git a/edb/server/pgcon/pgcon.pxd b/edb/server/pgcon/pgcon.pxd index f1260f07bc2..6b9cf16b0c1 100644 --- a/edb/server/pgcon/pgcon.pxd +++ b/edb/server/pgcon/pgcon.pxd @@ -191,4 +191,4 @@ cdef class PGConnection: cdef inline str get_tenant_label(self) cpdef set_stmt_cache_size(self, int maxsize) -cdef setting_to_sql(setting) +cdef setting_to_sql(name, setting) diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index 241e7be2b67..43201ab7ed9 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -1824,7 +1824,9 @@ cdef class PGConnection: msg_buf = WriteBuffer.new_message(b'D') msg_buf.write_int16(1) # number of column values setting = dbv.current_fe_settings()[setting_name] - msg_buf.write_len_prefixed_utf8(setting_to_sql(setting)) + msg_buf.write_len_prefixed_utf8( + setting_to_sql(setting_name, setting) + ) 
buf.write_buffer(msg_buf.end_message()) # CommandComplete @@ -3015,28 +3017,172 @@ cdef bytes FLUSH_MESSAGE = bytes(WriteBuffer.new_message(b'H').end_message()) cdef EdegDBCodecContext DEFAULT_CODEC_CONTEXT = EdegDBCodecContext() +# Settings that are enums or bools and should not be quoted. +# Can be retrived from PostgreSQL with: +# SELECt name FROM pg_catalog.pg_settings WHERE vartype IN ('enum', 'bool'); +cdef set ENUM_SETTINGS = { + 'allow_alter_system', + 'allow_in_place_tablespaces', + 'allow_system_table_mods', + 'archive_mode', + 'array_nulls', + 'autovacuum', + 'backslash_quote', + 'bytea_output', + 'check_function_bodies', + 'client_min_messages', + 'compute_query_id', + 'constraint_exclusion', + 'data_checksums', + 'data_sync_retry', + 'debug_assertions', + 'debug_logical_replication_streaming', + 'debug_parallel_query', + 'debug_pretty_print', + 'debug_print_parse', + 'debug_print_plan', + 'debug_print_rewritten', + 'default_toast_compression', + 'default_transaction_deferrable', + 'default_transaction_isolation', + 'default_transaction_read_only', + 'dynamic_shared_memory_type', + 'edb_stat_statements.save', + 'edb_stat_statements.track', + 'edb_stat_statements.track_planning', + 'edb_stat_statements.track_utility', + 'enable_async_append', + 'enable_bitmapscan', + 'enable_gathermerge', + 'enable_group_by_reordering', + 'enable_hashagg', + 'enable_hashjoin', + 'enable_incremental_sort', + 'enable_indexonlyscan', + 'enable_indexscan', + 'enable_material', + 'enable_memoize', + 'enable_mergejoin', + 'enable_nestloop', + 'enable_parallel_append', + 'enable_parallel_hash', + 'enable_partition_pruning', + 'enable_partitionwise_aggregate', + 'enable_partitionwise_join', + 'enable_presorted_aggregate', + 'enable_seqscan', + 'enable_sort', + 'enable_tidscan', + 'escape_string_warning', + 'event_triggers', + 'exit_on_error', + 'fsync', + 'full_page_writes', + 'geqo', + 'gss_accept_delegation', + 'hot_standby', + 'hot_standby_feedback', + 'huge_pages', + 
'huge_pages_status', + 'icu_validation_level', + 'ignore_checksum_failure', + 'ignore_invalid_pages', + 'ignore_system_indexes', + 'in_hot_standby', + 'integer_datetimes', + 'intervalstyle', + 'jit', + 'jit_debugging_support', + 'jit_dump_bitcode', + 'jit_expressions', + 'jit_profiling_support', + 'jit_tuple_deforming', + 'krb_caseins_users', + 'lo_compat_privileges', + 'log_checkpoints', + 'log_connections', + 'log_disconnections', + 'log_duration', + 'log_error_verbosity', + 'log_executor_stats', + 'log_hostname', + 'log_lock_waits', + 'log_min_error_statement', + 'log_min_messages', + 'log_parser_stats', + 'log_planner_stats', + 'log_recovery_conflict_waits', + 'log_replication_commands', + 'log_statement', + 'log_statement_stats', + 'log_truncate_on_rotation', + 'logging_collector', + 'parallel_leader_participation', + 'password_encryption', + 'plan_cache_mode', + 'quote_all_identifiers', + 'recovery_init_sync_method', + 'recovery_prefetch', + 'recovery_target_action', + 'recovery_target_inclusive', + 'remove_temp_files_after_crash', + 'restart_after_crash', + 'row_security', + 'send_abort_for_crash', + 'send_abort_for_kill', + 'session_replication_role', + 'shared_memory_type', + 'ssl', + 'ssl_max_protocol_version', + 'ssl_min_protocol_version', + 'ssl_passphrase_command_supports_reload', + 'ssl_prefer_server_ciphers', + 'standard_conforming_strings', + 'stats_fetch_consistency', + 'summarize_wal', + 'sync_replication_slots', + 'synchronize_seqscans', + 'synchronous_commit', + 'syslog_facility', + 'syslog_sequence_numbers', + 'syslog_split_messages', + 'trace_connection_negotiation', + 'trace_notify', + 'trace_sort', + 'track_activities', + 'track_commit_timestamp', + 'track_counts', + 'track_functions', + 'track_io_timing', + 'track_wal_io_timing', + 'transaction_deferrable', + 'transaction_isolation', + 'transaction_read_only', + 'transform_null_equals', + 'update_process_title', + 'wal_compression', + 'wal_init_zero', + 'wal_level', + 'wal_log_hints', + 
'wal_receiver_create_temp_slot', + 'wal_recycle', + 'wal_sync_method', + 'xmlbinary', + 'xmloption', + 'zero_damaged_pages', +} + + +cdef setting_to_sql(name, setting): + is_enum = name.lower() in ENUM_SETTINGS -cdef setting_to_sql(setting): assert typeutils.is_container(setting) - return ', '.join(setting_val_to_sql(v) for v in setting) - - -cdef set NON_QUOTABLE_STRINGS = { - 'repeatable read', - 'read committed', - 'read uncommitted', - 'off', - 'on', - 'yes', - 'no', - 'true', - 'false', -} + return ', '.join(setting_val_to_sql(v, is_enum) for v in setting) -cdef inline str setting_val_to_sql(val: str | int | float): +cdef inline str setting_val_to_sql(val: str | int | float, is_enum: bool): if isinstance(val, str): - if val in NON_QUOTABLE_STRINGS: + if is_enum: # special case: no quoting return val # quote as identifier diff --git a/edb/server/protocol/pg_ext.pyx b/edb/server/protocol/pg_ext.pyx index cf5a040359b..59e699c36fd 100644 --- a/edb/server/protocol/pg_ext.pyx +++ b/edb/server/protocol/pg_ext.pyx @@ -355,7 +355,7 @@ cdef class ConnectionView: return self._session_state_db_cache[1] rv = json.dumps({ - key: setting_to_sql(val) for key, val in self._settings.items() + key: setting_to_sql(key, val) for key, val in self._settings.items() }).encode("utf-8") self._session_state_db_cache = (self._settings, rv) return rv diff --git a/rust/auth/src/scram.rs b/rust/auth/src/scram.rs index bcf055d9aae..8c1e8ace725 100644 --- a/rust/auth/src/scram.rs +++ b/rust/auth/src/scram.rs @@ -6,7 +6,7 @@ //! protocols like Postgres and SASL to enhance security against common attacks //! such as replay and man-in-the-middle attacks. //! -//! https://en.wikipedia.org/wiki/Salted_Challenge_Response_Authentication_Mechanism +//! //! //! ## Limitations of this implementation //! 
diff --git a/rust/captive_postgres/Cargo.toml b/rust/captive_postgres/Cargo.toml new file mode 100644 index 00000000000..995a4b6d482 --- /dev/null +++ b/rust/captive_postgres/Cargo.toml @@ -0,0 +1,17 @@ + +[package] +name = "captive_postgres" +version = "0.1.0" +license = "MIT/Apache-2.0" +authors = ["MagicStack Inc. "] +edition = "2021" + +[lints] +workspace = true + +[dependencies] +gel_auth.workspace = true + +openssl = "0.10.55" +tempfile = "3" +socket2 = "0.5.8" diff --git a/rust/captive_postgres/README.md b/rust/captive_postgres/README.md new file mode 100644 index 00000000000..36c9d67cd65 --- /dev/null +++ b/rust/captive_postgres/README.md @@ -0,0 +1,5 @@ +# captive_postgres + +A simple, captive Postgres server that can be used to test client connections. Each instance +is a freshly initialized Postgres server with the specified credentials. + diff --git a/rust/captive_postgres/src/lib.rs b/rust/captive_postgres/src/lib.rs new file mode 100644 index 00000000000..03be2d2ff53 --- /dev/null +++ b/rust/captive_postgres/src/lib.rs @@ -0,0 +1,384 @@ +// Constants +use gel_auth::AuthType; +use openssl::ssl::{Ssl, SslContext, SslMethod}; +use std::io::{BufRead, BufReader, Write}; +use std::net::{Ipv4Addr, SocketAddr, TcpListener}; +use std::os::unix::fs::PermissionsExt; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; +use std::sync::{Arc, RwLock}; +use std::thread; +use std::time::{Duration, Instant}; +use tempfile::TempDir; + +pub const STARTUP_TIMEOUT_DURATION: Duration = Duration::from_secs(30); +pub const PORT_RELEASE_TIMEOUT: Duration = Duration::from_secs(30); +pub const LINGER_DURATION: Duration = Duration::from_secs(1); +pub const HOT_LOOP_INTERVAL: Duration = Duration::from_millis(100); +pub const DEFAULT_USERNAME: &str = "username"; +pub const DEFAULT_PASSWORD: &str = "password"; +pub const DEFAULT_DATABASE: &str = "postgres"; + +#[derive(Debug, Clone)] +pub enum ListenAddress { + Tcp(SocketAddr), + #[cfg(unix)] + Unix(PathBuf), +} + 
+/// Represents an ephemeral port that can be allocated and released for immediate re-use by another process. +struct EphemeralPort { + port: u16, + listener: Option, +} + +impl EphemeralPort { + /// Allocates a new ephemeral port. + /// + /// Returns a Result containing the EphemeralPort if successful, + /// or an IO error if the allocation fails. + fn allocate() -> std::io::Result { + let socket = socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + socket.set_linger(Some(LINGER_DURATION))?; + socket.bind(&std::net::SocketAddr::from((Ipv4Addr::LOCALHOST, 0)).into())?; + socket.listen(1)?; + let listener = TcpListener::from(socket); + let port = listener.local_addr()?.port(); + Ok(EphemeralPort { + port, + listener: Some(listener), + }) + } + + /// Consumes the EphemeralPort and returns the allocated port number. + fn take(self) -> u16 { + // Drop the listener to free up the port + drop(self.listener); + + // Loop until the port is free + let start = Instant::now(); + + // If we can successfully connect to the port, it's not fully closed + while start.elapsed() < PORT_RELEASE_TIMEOUT { + let res = std::net::TcpStream::connect((Ipv4Addr::LOCALHOST, self.port)); + if res.is_err() { + // If connection fails, the port is released + break; + } + std::thread::sleep(HOT_LOOP_INTERVAL); + } + + self.port + } +} + +struct StdioReader { + output: Arc>, +} + +impl StdioReader { + fn spawn(reader: R, prefix: &'static str) -> Self { + let output = Arc::new(RwLock::new(String::new())); + let output_clone = Arc::clone(&output); + + thread::spawn(move || { + let mut buf_reader = std::io::BufReader::new(reader); + loop { + let mut line = String::new(); + match buf_reader.read_line(&mut line) { + Ok(0) => break, + Ok(_) => { + if let Ok(mut output) = output_clone.write() { + output.push_str(&line); + } + eprint!("[{}]: {}", prefix, line); + } + Err(e) => { + let error_line = format!("Error 
reading {}: {}\n", prefix, e); + if let Ok(mut output) = output_clone.write() { + output.push_str(&error_line); + } + eprintln!("{}", error_line); + } + } + } + }); + + StdioReader { output } + } + + fn contains(&self, s: &str) -> bool { + if let Ok(output) = self.output.read() { + output.contains(s) + } else { + false + } + } +} + +fn init_postgres(initdb: &Path, data_dir: &Path, auth: AuthType) -> std::io::Result<()> { + let mut pwfile = tempfile::NamedTempFile::new()?; + writeln!(pwfile, "{}", DEFAULT_PASSWORD)?; + let mut command = Command::new(initdb); + command + .arg("-D") + .arg(data_dir) + .arg("-A") + .arg(match auth { + AuthType::Deny => "reject", + AuthType::Trust => "trust", + AuthType::Plain => "password", + AuthType::Md5 => "md5", + AuthType::ScramSha256 => "scram-sha-256", + }) + .arg("--pwfile") + .arg(pwfile.path()) + .arg("-U") + .arg(DEFAULT_USERNAME); + + let output = command.output()?; + + let status = output.status; + let output_str = String::from_utf8_lossy(&output.stdout).to_string(); + let error_str = String::from_utf8_lossy(&output.stderr).to_string(); + + eprintln!("initdb stdout:\n{}", output_str); + eprintln!("initdb stderr:\n{}", error_str); + + if !status.success() { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "initdb command failed", + )); + } + + Ok(()) +} + +fn run_postgres( + postgres_bin: &Path, + data_dir: &Path, + socket_path: &Path, + ssl: Option<(PathBuf, PathBuf)>, + port: u16, +) -> std::io::Result { + let mut command = Command::new(postgres_bin); + command + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .arg("-D") + .arg(data_dir) + .arg("-k") + .arg(socket_path) + .arg("-h") + .arg(Ipv4Addr::LOCALHOST.to_string()) + .arg("-F") + // Useful for debugging + // .arg("-d") + // .arg("5") + .arg("-p") + .arg(port.to_string()); + + if let Some((cert_path, key_path)) = ssl { + let postgres_cert_path = data_dir.join("server.crt"); + let postgres_key_path = data_dir.join("server.key"); + 
std::fs::copy(cert_path, &postgres_cert_path)?; + std::fs::copy(key_path, &postgres_key_path)?; + // Set permissions for the certificate and key files + std::fs::set_permissions(&postgres_cert_path, std::fs::Permissions::from_mode(0o600))?; + std::fs::set_permissions(&postgres_key_path, std::fs::Permissions::from_mode(0o600))?; + + // Edit pg_hba.conf to change all "host" line prefixes to "hostssl" + let pg_hba_path = data_dir.join("pg_hba.conf"); + let content = std::fs::read_to_string(&pg_hba_path)?; + let modified_content = content + .lines() + .filter(|line| !line.starts_with("#") && !line.is_empty()) + .map(|line| { + if line.trim_start().starts_with("host") { + line.replacen("host", "hostssl", 1) + } else { + line.to_string() + } + }) + .collect::>() + .join("\n"); + eprintln!("pg_hba.conf:\n==========\n{modified_content}\n=========="); + std::fs::write(&pg_hba_path, modified_content)?; + + command.arg("-l"); + } + + let mut child = command.spawn()?; + + let stdout_reader = BufReader::new(child.stdout.take().expect("Failed to capture stdout")); + let _ = StdioReader::spawn(stdout_reader, "stdout"); + let stderr_reader = BufReader::new(child.stderr.take().expect("Failed to capture stderr")); + let stderr_reader = StdioReader::spawn(stderr_reader, "stderr"); + + let start_time = Instant::now(); + + let mut tcp_socket: Option = None; + let mut unix_socket: Option = None; + + let unix_socket_path = get_unix_socket_path(socket_path, port); + let tcp_socket_addr = std::net::SocketAddr::from((Ipv4Addr::LOCALHOST, port)); + let mut db_ready = false; + + while start_time.elapsed() < STARTUP_TIMEOUT_DURATION { + std::thread::sleep(HOT_LOOP_INTERVAL); + match child.try_wait() { + Ok(Some(status)) => { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + format!("PostgreSQL exited with status: {}", status), + )) + } + Err(e) => return Err(e), + _ => {} + } + if !db_ready && stderr_reader.contains("database system is ready to accept connections") { + 
eprintln!("Database is ready"); + db_ready = true; + } else { + continue; + } + if unix_socket.is_none() { + unix_socket = std::os::unix::net::UnixStream::connect(&unix_socket_path).ok(); + } + if tcp_socket.is_none() { + tcp_socket = std::net::TcpStream::connect(tcp_socket_addr).ok(); + } + if unix_socket.is_some() && tcp_socket.is_some() { + break; + } + } + + if unix_socket.is_some() && tcp_socket.is_some() { + return Ok(child); + } + + // Print status for TCP/unix sockets + if let Some(tcp) = &tcp_socket { + eprintln!( + "TCP socket at {tcp_socket_addr:?} bound successfully on {}", + tcp.local_addr()? + ); + } else { + eprintln!("TCP socket at {tcp_socket_addr:?} binding failed"); + } + + if unix_socket.is_some() { + eprintln!("Unix socket at {unix_socket_path:?} connected successfully"); + } else { + eprintln!("Unix socket at {unix_socket_path:?} connection failed"); + } + + Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "PostgreSQL failed to start within 30 seconds", + )) +} + +fn test_data_dir() -> std::path::PathBuf { + let cargo_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../tests"); + if cargo_path.exists() { + cargo_path + } else { + Path::new("../../tests") + .canonicalize() + .expect("Failed to canonicalize tests directory path") + } +} + +fn postgres_bin_dir() -> std::io::Result { + let cargo_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../build/postgres/install/bin"); + if cargo_path.exists() { + cargo_path.canonicalize() + } else { + Path::new("../../build/postgres/install/bin").canonicalize() + } +} + +fn get_unix_socket_path(socket_path: &Path, port: u16) -> PathBuf { + socket_path.join(format!(".s.PGSQL.{}", port)) +} + +#[derive(Debug, Clone, Copy)] +pub enum Mode { + Tcp, + TcpSsl, + Unix, +} + +pub fn create_ssl_client() -> Result> { + let ssl_context = SslContext::builder(SslMethod::tls_client())?.build(); + let mut ssl = Ssl::new(&ssl_context)?; + ssl.set_connect_state(); + Ok(ssl) +} + +pub struct 
PostgresProcess { + child: std::process::Child, + pub socket_address: ListenAddress, + #[allow(unused)] + temp_dir: TempDir, +} + +impl Drop for PostgresProcess { + fn drop(&mut self) { + let _ = self.child.kill(); + } +} + +/// Creates and runs a new Postgres server process in a temporary directory. +pub fn setup_postgres( + auth: AuthType, + mode: Mode, +) -> Result, Box> { + let Ok(bindir) = postgres_bin_dir() else { + println!("Skipping test: postgres bin dir not found"); + return Ok(None); + }; + + let initdb = bindir.join("initdb"); + let postgres = bindir.join("postgres"); + + if !initdb.exists() || !postgres.exists() { + println!("Skipping test: initdb or postgres not found"); + return Ok(None); + } + + let temp_dir = TempDir::new()?; + let port = EphemeralPort::allocate()?; + let data_dir = temp_dir.path().join("data"); + + init_postgres(&initdb, &data_dir, auth)?; + let ssl_key = match mode { + Mode::TcpSsl => { + let certs_dir = test_data_dir().join("certs"); + let cert = certs_dir.join("server.cert.pem"); + let key = certs_dir.join("server.key.pem"); + Some((cert, key)) + } + _ => None, + }; + + let port = port.take(); + let child = run_postgres(&postgres, &data_dir, &data_dir, ssl_key, port)?; + + let socket_address = match mode { + Mode::Unix => ListenAddress::Unix(get_unix_socket_path(&data_dir, port)), + Mode::Tcp | Mode::TcpSsl => { + ListenAddress::Tcp(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port)) + } + }; + + Ok(Some(PostgresProcess { + child, + socket_address, + temp_dir, + })) +} diff --git a/rust/db_proto/Cargo.toml b/rust/db_proto/Cargo.toml new file mode 100644 index 00000000000..80d8bca57f0 --- /dev/null +++ b/rust/db_proto/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "db_proto" +version = "0.1.0" +license = "MIT/Apache-2.0" +authors = ["MagicStack Inc. 
"] +edition = "2021" + +[lints] +workspace = true + +[dependencies] +thiserror = "1" +paste = "1" +derive_more = { version = "1", features = ["full"] } +uuid = "1" + +[dev-dependencies] +pretty_assertions = "1.2.0" diff --git a/rust/db_proto/README.md b/rust/db_proto/README.md new file mode 100644 index 00000000000..aae37968447 --- /dev/null +++ b/rust/db_proto/README.md @@ -0,0 +1,4 @@ +# db_proto + +This is a crate that makes parsing and serializing of PostgreSQL-like protocols +(ie: Postgres itself, as well as Gel/EdgeDB) easier. diff --git a/rust/pgrust/src/protocol/arrays.rs b/rust/db_proto/src/arrays.rs similarity index 53% rename from rust/pgrust/src/protocol/arrays.rs rename to rust/db_proto/src/arrays.rs index 9e7811ca049..c65a5118f82 100644 --- a/rust/pgrust/src/protocol/arrays.rs +++ b/rust/db_proto/src/arrays.rs @@ -1,9 +1,11 @@ #![allow(private_bounds)] -use super::{Enliven, FieldAccessArray, FixedSize, Meta, MetaRelation}; + +use super::{Enliven, FieldAccessArray, FixedSize, Meta, MetaRelation, ParseError}; pub use std::marker::PhantomData; pub mod meta { pub use super::ArrayMeta as Array; + pub use super::FixedArrayMeta as FixedArray; pub use super::ZTArrayMeta as ZTArray; } @@ -15,7 +17,7 @@ pub struct ZTArray<'a, T: FieldAccessArray> { /// Metaclass for [`ZTArray`]. 
pub struct ZTArrayMeta { - pub(crate) _phantom: PhantomData, + pub _phantom: PhantomData, } impl Meta for ZTArrayMeta { @@ -95,6 +97,38 @@ impl<'a, T: FieldAccessArray> Iterator for ZTArrayIter<'a, T> { } } +impl FieldAccessArray for ZTArrayMeta { + const META: &'static dyn Meta = &ZTArrayMeta:: { + _phantom: PhantomData, + }; + fn size_of_field_at(mut buf: &[u8]) -> Result { + let mut size = 1; + loop { + if buf.is_empty() { + return Err(ParseError::TooShort); + } + if buf[0] == 0 { + return Ok(size); + } + let elem_size = match T::size_of_field_at(buf) { + Ok(n) => n, + Err(e) => return Err(e), + }; + buf = buf.split_at(elem_size).1; + size += elem_size; + } + } + fn extract(buf: &[u8]) -> Result, ParseError> { + Ok(ZTArray::new(buf)) + } + fn copy_to_buf(buf: &mut crate::BufWriter, value: &&[::ForBuilder<'_>]) { + for elem in *value { + T::copy_to_buf(buf, elem); + } + buf.write_u8(0); + } +} + /// Inflated version of a length-specified array with zero-copy iterator access. pub struct Array<'a, L, T: FieldAccessArray> { _phantom: PhantomData<(L, T)>, @@ -104,7 +138,7 @@ pub struct Array<'a, L, T: FieldAccessArray> { /// Metaclass for [`Array`]. pub struct ArrayMeta { - pub(crate) _phantom: PhantomData<(L, T)>, + pub _phantom: PhantomData<(L, T)>, } impl Meta for ArrayMeta { @@ -121,7 +155,7 @@ impl Meta for ArrayMeta { impl Enliven for ArrayMeta where - T: FieldAccessArray, + T: FieldAccessArray + Enliven, { type WithLifetime<'a> = Array<'a, L, T>; type ForMeasure<'a> = &'a [::ForMeasure<'a>]; @@ -203,130 +237,6 @@ impl<'a, T: FieldAccessArray> Iterator for ArrayIter<'a, T> { } } -/// Definate array accesses for inflated, strongly-typed arrays of both -/// zero-terminated and length-delimited types. -macro_rules! 
array_access { - ($ty:ty) => { - $crate::protocol::arrays::array_access!($ty | u8 i16 i32); - }; - ($ty:ty | $($len:ty)*) => { - $( - #[allow(unused)] - impl $crate::protocol::FieldAccess<$crate::protocol::meta::Array<$len, $ty>> { - pub const fn meta() -> &'static dyn $crate::protocol::Meta { - &$crate::protocol::meta::Array::<$len, $ty> { _phantom: std::marker::PhantomData } - } - #[inline] - pub const fn size_of_field_at(mut buf: &[u8]) -> Result { - let mut size = std::mem::size_of::<$len>(); - let mut len = match $crate::protocol::FieldAccess::<$len>::extract(buf) { - Ok(n) => n, - Err(e) => return Err(e), - }; - #[allow(unused_comparisons)] - if len < 0 { - return Err($crate::protocol::ParseError::InvalidData); - } - buf = buf.split_at(size).1; - loop { - if len <= 0 { - break; - } - len -= 1; - let elem_size = match $crate::protocol::FieldAccess::<$ty>::size_of_field_at(buf) { - Ok(n) => n, - Err(e) => return Err(e), - }; - buf = buf.split_at(elem_size).1; - size += elem_size; - } - Ok(size) - } - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result<$crate::protocol::Array<'_, $len, $ty>, $crate::protocol::ParseError> { - match $crate::protocol::FieldAccess::<$len>::extract(buf) { - Ok(len) => Ok($crate::protocol::Array::new(buf.split_at(std::mem::size_of::<$len>()).1, len as u32)), - Err(e) => Err(e) - } - } - #[inline] - pub const fn measure<'a>(buffer: &'a[<$ty as $crate::protocol::Enliven>::ForMeasure<'a>]) -> usize { - let mut size = std::mem::size_of::<$len>(); - let mut index = 0; - loop { - if index + 1 > buffer.len() { - break; - } - let item = &buffer[index]; - size += $crate::protocol::FieldAccess::<$ty>::measure(item); - index += 1; - } - size - } - #[inline(always)] - pub fn copy_to_buf<'a>(buf: &mut $crate::protocol::writer::BufWriter, value: &'a[<$ty as $crate::protocol::Enliven>::ForBuilder<'a>]) { - buf.write(&<$len>::to_be_bytes(value.len() as _)); - for elem in value { - $crate::protocol::FieldAccess::<$ty>::copy_to_buf_ref(buf, 
elem); - } - } - - } - )* - - #[allow(unused)] - impl $crate::protocol::FieldAccess<$crate::protocol::meta::ZTArray<$ty>> { - pub const fn meta() -> &'static dyn $crate::protocol::Meta { - &$crate::protocol::meta::ZTArray::<$ty> { _phantom: std::marker::PhantomData } - } - #[inline] - pub const fn size_of_field_at(mut buf: &[u8]) -> Result { - let mut size = 1; - loop { - if buf.is_empty() { - return Err($crate::protocol::ParseError::TooShort); - } - if buf[0] == 0 { - return Ok(size); - } - let elem_size = match $crate::protocol::FieldAccess::<$ty>::size_of_field_at(buf) { - Ok(n) => n, - Err(e) => return Err(e), - }; - buf = buf.split_at(elem_size).1; - size += elem_size; - } - } - #[inline(always)] - pub const fn extract(mut buf: &[u8]) -> Result<$crate::protocol::ZTArray<$ty>, $crate::protocol::ParseError> { - Ok($crate::protocol::ZTArray::new(buf)) - } - #[inline] - pub const fn measure<'a>(buffer: &'a[<$ty as $crate::protocol::Enliven>::ForMeasure<'a>]) -> usize { - let mut size = 1; - let mut index = 0; - loop { - if index + 1 > buffer.len() { - break; - } - let item = &buffer[index]; - size += $crate::protocol::FieldAccess::<$ty>::measure(item); - index += 1; - } - size - } - #[inline(always)] - pub fn copy_to_buf(buf: &mut $crate::protocol::writer::BufWriter, value: &[<$ty as $crate::protocol::Enliven>::ForBuilder<'_>]) { - for elem in value { - $crate::protocol::FieldAccess::<$ty>::copy_to_buf_ref(buf, elem); - } - buf.write_u8(0); - } - } - }; -} -pub(crate) use array_access; - // Arrays of type [`u8`] are special-cased to return a slice of bytes. 
impl AsRef<[u8]> for Array<'_, T, u8> { fn as_ref(&self) -> &[u8] { @@ -350,3 +260,82 @@ impl<'a, L: TryInto, T: FixedSize + FieldAccessArray> Array<'a, L, T> { } } } + +impl FieldAccessArray + for ArrayMeta +where + for<'a> L::ForBuilder<'a>: TryFrom, + for<'a> L::WithLifetime<'a>: TryInto, +{ + const META: &'static dyn Meta = &ArrayMeta:: { + _phantom: PhantomData, + }; + fn size_of_field_at(mut buf: &[u8]) -> Result { + let mut size = std::mem::size_of::(); + let len = match L::extract(buf) { + Ok(n) => n.try_into(), + Err(e) => return Err(e), + }; + #[allow(unused_comparisons)] + let Ok(mut len) = len + else { + return Err(ParseError::InvalidData); + }; + buf = buf.split_at(size).1; + loop { + if len == 0 { + break; + } + len -= 1; + let elem_size = match T::size_of_field_at(buf) { + Ok(n) => n, + Err(e) => return Err(e), + }; + buf = buf.split_at(elem_size).1; + size += elem_size; + } + Ok(size) + } + fn extract(buf: &[u8]) -> Result, ParseError> { + let len = match L::extract(buf) { + Ok(len) => len.try_into(), + Err(e) => { + return Err(e); + } + }; + let Ok(len) = len else { + return Err(ParseError::InvalidData); + }; + Ok(Array::new( + buf.split_at(std::mem::size_of::()).1, + len as _, + )) + } + fn copy_to_buf(buf: &mut crate::BufWriter, value: &&[::ForBuilder<'_>]) { + let Ok(len) = L::ForBuilder::try_from(value.len()) else { + panic!("Array length out of bounds"); + }; + L::copy_to_buf(buf, &len); + for elem in *value { + T::copy_to_buf(buf, elem); + } + } +} + +pub struct FixedArrayMeta { + pub _phantom: PhantomData, +} + +impl Meta for FixedArrayMeta { + fn name(&self) -> &'static str { + "FixedArray" + } + + fn fixed_length(&self) -> Option { + Some(S) + } + + fn relations(&self) -> &'static [(MetaRelation, &'static dyn Meta)] { + &[(MetaRelation::Item, ::META)] + } +} diff --git a/rust/pgrust/src/protocol/buffer.rs b/rust/db_proto/src/buffer.rs similarity index 98% rename from rust/pgrust/src/protocol/buffer.rs rename to rust/db_proto/src/buffer.rs 
index 6837407b41a..462978388c2 100644 --- a/rust/pgrust/src/protocol/buffer.rs +++ b/rust/db_proto/src/buffer.rs @@ -154,9 +154,10 @@ impl StructBuffer { #[cfg(test)] mod tests { + use crate::{Encoded, ParseError}; + use super::StructBuffer; - use crate::protocol::postgres::{builder, data::*, meta}; - use crate::protocol::*; + use crate::test_protocol::{builder, data::*, meta}; /// Create a test data buffer containing three messages fn test_data() -> (Vec, Vec) { diff --git a/rust/db_proto/src/datatypes.rs b/rust/db_proto/src/datatypes.rs new file mode 100644 index 00000000000..a2f289fdd4c --- /dev/null +++ b/rust/db_proto/src/datatypes.rs @@ -0,0 +1,558 @@ +use std::{marker::PhantomData, str::Utf8Error}; + +pub use uuid::Uuid; + +use crate::{ + declare_field_access, declare_field_access_fixed_size, writer::BufWriter, Enliven, FieldAccess, + FieldAccessArray, Meta, ParseError, +}; + +pub mod meta { + pub use super::BasicMeta as Basic; + pub use super::EncodedMeta as Encoded; + pub use super::LStringMeta as LString; + pub use super::LengthMeta as Length; + pub use super::RestMeta as Rest; + pub use super::UuidMeta as Uuid; + pub use super::ZTStringMeta as ZTString; +} + +/// Represents the remainder of data in a message. +#[derive(Debug, PartialEq, Eq)] +pub struct Rest<'a> { + buf: &'a [u8], +} + +declare_field_access! 
{ + Meta = RestMeta, + Inflated = Rest<'a>, + Measure = &'a [u8], + Builder = &'a [u8], + + pub const fn meta() -> &'static dyn Meta { + &RestMeta {} + } + + pub const fn size_of_field_at(buf: &[u8]) -> Result { + Ok(buf.len()) + } + + pub const fn extract(buf: &[u8]) -> Result, ParseError> { + Ok(Rest { buf }) + } + + pub const fn measure(buf: &[u8]) -> usize { + buf.len() + } + + pub fn copy_to_buf(buf: &mut BufWriter, value: &[u8]) { + buf.write(value) + } + + pub const fn constant(_constant: usize) -> Rest<'static> { + panic!("Constants unsupported for this data type") + } +} + +pub struct RestMeta {} +impl Meta for RestMeta { + fn name(&self) -> &'static str { + "Rest" + } +} + +impl<'a> Rest<'a> {} + +impl<'a> AsRef<[u8]> for Rest<'a> { + fn as_ref(&self) -> &[u8] { + self.buf + } +} + +impl<'a> std::ops::Deref for Rest<'a> { + type Target = [u8]; + fn deref(&self) -> &Self::Target { + self.buf + } +} + +impl PartialEq<[u8]> for Rest<'_> { + fn eq(&self, other: &[u8]) -> bool { + self.buf == other + } +} + +impl PartialEq<&[u8; N]> for Rest<'_> { + fn eq(&self, other: &&[u8; N]) -> bool { + self.buf == *other + } +} + +impl PartialEq<&[u8]> for Rest<'_> { + fn eq(&self, other: &&[u8]) -> bool { + self.buf == *other + } +} + +/// A zero-terminated string. 
+#[allow(unused)] +pub struct ZTString<'a> { + buf: &'a [u8], +} + +declare_field_access!( + Meta = ZTStringMeta, + Inflated = ZTString<'a>, + Measure = &'a str, + Builder = &'a str, + + pub const fn meta() -> &'static dyn Meta { + &ZTStringMeta {} + } + + pub const fn size_of_field_at(buf: &[u8]) -> Result { + let mut i = 0; + loop { + if i >= buf.len() { + return Err(ParseError::TooShort); + } + if buf[i] == 0 { + return Ok(i + 1); + } + i += 1; + } + } + + pub const fn extract(buf: &[u8]) -> Result, ParseError> { + let buf = buf.split_at(buf.len() - 1).0; + Ok(ZTString { buf }) + } + + pub const fn measure(buf: &str) -> usize { + buf.len() + 1 + } + + pub fn copy_to_buf(buf: &mut BufWriter, value: &str) { + buf.write(value.as_bytes()); + buf.write_u8(0); + } + + pub const fn constant(_constant: usize) -> ZTString<'static> { + panic!("Constants unsupported for this data type") + } +); + +pub struct ZTStringMeta {} +impl Meta for ZTStringMeta { + fn name(&self) -> &'static str { + "ZTString" + } +} + +impl std::fmt::Debug for ZTString<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + String::from_utf8_lossy(self.buf).fmt(f) + } +} + +impl<'a> ZTString<'a> { + pub fn to_owned(&self) -> Result { + std::str::from_utf8(self.buf).map(|s| s.to_owned()) + } + + pub fn to_str(&self) -> Result<&str, std::str::Utf8Error> { + std::str::from_utf8(self.buf) + } + + pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { + String::from_utf8_lossy(self.buf) + } + + pub fn to_bytes(&self) -> &[u8] { + self.buf + } +} + +impl PartialEq for ZTString<'_> { + fn eq(&self, other: &Self) -> bool { + self.buf == other.buf + } +} +impl Eq for ZTString<'_> {} + +impl PartialEq for ZTString<'_> { + fn eq(&self, other: &str) -> bool { + self.buf == other.as_bytes() + } +} + +impl PartialEq<&str> for ZTString<'_> { + fn eq(&self, other: &&str) -> bool { + self.buf == other.as_bytes() + } +} + +impl<'a> TryInto<&'a str> for ZTString<'a> { + type Error = 
Utf8Error; + fn try_into(self) -> Result<&'a str, Self::Error> { + std::str::from_utf8(self.buf) + } +} + +/// A length-prefixed string. +#[allow(unused)] +pub struct LString<'a> { + buf: &'a [u8], +} + +declare_field_access!( + Meta = LStringMeta, + Inflated = LString<'a>, + Measure = &'a str, + Builder = &'a str, + + pub const fn meta() -> &'static dyn Meta { + &LStringMeta {} + } + + pub const fn size_of_field_at(buf: &[u8]) -> Result { + if buf.len() < 4 { + return Err(ParseError::TooShort); + } + let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + Ok(4 + len) + } + + pub const fn extract(buf: &[u8]) -> Result, ParseError> { + if buf.len() < 4 { + return Err(ParseError::TooShort); + } + let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + if buf.len() < 4 + len { + return Err(ParseError::TooShort); + } + Ok(LString { + buf: buf.split_at(4).1, + }) + } + + pub const fn measure(buf: &str) -> usize { + 4 + buf.len() + } + + pub fn copy_to_buf(buf: &mut BufWriter, value: &str) { + let len = value.len() as u32; + buf.write(&len.to_be_bytes()); + buf.write(value.as_bytes()); + } + + pub const fn constant(_constant: usize) -> LString<'static> { + panic!("Constants unsupported for this data type") + } +); + +pub struct LStringMeta {} +impl Meta for LStringMeta { + fn name(&self) -> &'static str { + "LString" + } +} + +impl std::fmt::Debug for LString<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + String::from_utf8_lossy(self.buf).fmt(f) + } +} + +impl<'a> LString<'a> { + pub fn to_owned(&self) -> Result { + std::str::from_utf8(self.buf).map(|s| s.to_owned()) + } + + pub fn to_str(&self) -> Result<&str, std::str::Utf8Error> { + std::str::from_utf8(self.buf) + } + + pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { + String::from_utf8_lossy(self.buf) + } + + pub fn to_bytes(&self) -> &[u8] { + self.buf + } +} + +impl PartialEq for LString<'_> { + fn eq(&self, other: &Self) -> bool { + 
self.buf == other.buf + } +} +impl Eq for LString<'_> {} + +impl PartialEq for LString<'_> { + fn eq(&self, other: &str) -> bool { + self.buf == other.as_bytes() + } +} + +impl PartialEq<&str> for LString<'_> { + fn eq(&self, other: &&str) -> bool { + self.buf == other.as_bytes() + } +} + +impl<'a> TryInto<&'a str> for LString<'a> { + type Error = Utf8Error; + fn try_into(self) -> Result<&'a str, Self::Error> { + std::str::from_utf8(self.buf) + } +} + +declare_field_access_fixed_size! { + Meta = UuidMeta, + Inflated = Uuid, + Measure = Uuid, + Builder = Uuid, + Size = 16, + Zero = Uuid::nil(), + + pub const fn meta() -> &'static dyn Meta { + &UuidMeta {} + } + + pub const fn extract(buf: &[u8; 16]) -> Result { + Ok(Uuid::from_u128(::from_be_bytes(*buf))) + } + + pub fn copy_to_buf(buf: &mut BufWriter, value: &Uuid) { + buf.write(value.as_bytes().as_slice()) + } + + pub const fn constant(_constant: usize) -> Uuid { + panic!("Constants unsupported for this data type") + } +} + +pub struct UuidMeta {} +impl Meta for UuidMeta { + fn name(&self) -> &'static str { + "Uuid" + } +} + +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] +/// An encoded row value. +pub enum Encoded<'a> { + #[default] + Null, + Value(&'a [u8]), +} + +impl<'a> Encoded<'a> { + pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { + match self { + Encoded::Null => "".into(), + Encoded::Value(value) => String::from_utf8_lossy(value), + } + } +} + +impl<'a> AsRef> for Encoded<'a> { + fn as_ref(&self) -> &Encoded<'a> { + self + } +} + +declare_field_access! 
{ + Meta = EncodedMeta, + Inflated = Encoded<'a>, + Measure = Encoded<'a>, + Builder = Encoded<'a>, + + pub const fn meta() -> &'static dyn Meta { + &EncodedMeta {} + } + + pub const fn size_of_field_at(buf: &[u8]) -> Result { + const N: usize = std::mem::size_of::(); + if let Some(len) = buf.first_chunk::() { + let len = i32::from_be_bytes(*len); + if len == -1 { + Ok(N) + } else if len < 0 { + Err(ParseError::InvalidData) + } else if buf.len() < len as usize + N { + Err(ParseError::TooShort) + } else { + Ok(len as usize + N) + } + } else { + Err(ParseError::TooShort) + } + } + + pub const fn extract(buf: &[u8]) -> Result, ParseError> { + const N: usize = std::mem::size_of::(); + if let Some((len, array)) = buf.split_first_chunk::() { + let len = i32::from_be_bytes(*len); + if len == -1 && array.is_empty() { + Ok(Encoded::Null) + } else if len < 0 { + Err(ParseError::InvalidData) + } else if array.len() < len as _ { + Err(ParseError::TooShort) + } else { + Ok(Encoded::Value(array)) + } + } else { + Err(ParseError::TooShort) + } + } + + pub const fn measure(value: &Encoded) -> usize { + match value { + Encoded::Null => std::mem::size_of::(), + Encoded::Value(value) => value.len() + std::mem::size_of::(), + } + } + + pub fn copy_to_buf(buf: &mut BufWriter, value: &Encoded) { + match value { + Encoded::Null => buf.write(&[0xff, 0xff, 0xff, 0xff]), + Encoded::Value(value) => { + let len: i32 = value.len() as _; + buf.write(&len.to_be_bytes()); + buf.write(value); + } + } + } + + pub const fn constant(_constant: usize) -> Encoded<'static> { + panic!("Constants unsupported for this data type") + } +} + +pub struct EncodedMeta {} +impl Meta for EncodedMeta { + fn name(&self) -> &'static str { + "Encoded" + } +} + +impl<'a> Encoded<'a> {} + +impl PartialEq for Encoded<'_> { + fn eq(&self, other: &str) -> bool { + self == &Encoded::Value(other.as_bytes()) + } +} + +impl PartialEq<&str> for Encoded<'_> { + fn eq(&self, other: &&str) -> bool { + self == 
&Encoded::Value(other.as_bytes()) + } +} + +impl PartialEq<[u8]> for Encoded<'_> { + fn eq(&self, other: &[u8]) -> bool { + self == &Encoded::Value(other) + } +} + +impl PartialEq<&[u8]> for Encoded<'_> { + fn eq(&self, other: &&[u8]) -> bool { + self == &Encoded::Value(other) + } +} + +pub struct Length(pub i32); + +declare_field_access_fixed_size! { + Meta = LengthMeta, + Inflated = usize, + Measure = i32, + Builder = i32, + Size = 4, + Zero = 0, + + pub const fn meta() -> &'static dyn Meta { + &LengthMeta {} + } + + pub const fn extract(buf: &[u8; 4]) -> Result { + let n = i32::from_be_bytes(*buf); + if n >= 0 { + Ok(n as _) + } else { + Err(ParseError::InvalidData) + } + } + + pub fn copy_to_buf(buf: &mut BufWriter, value: &i32) { + FieldAccess::::copy_to_buf(buf, value) + } + + pub const fn constant(value: usize) -> usize { + value + } +} + +impl FieldAccess { + pub fn copy_to_buf_rewind(buf: &mut BufWriter, rewind: usize, value: usize) { + buf.write_rewind(rewind, &(value as i32).to_be_bytes()); + } +} + +// We alias usize here. Note that if this causes trouble in the future we can +// probably work around this by adding a new "const value" function to +// FieldAccess. For now it works! +pub struct LengthMeta {} + +impl Meta for LengthMeta { + fn name(&self) -> &'static str { + "len" + } +} + +pub struct BasicMeta { + _phantom: PhantomData, +} + +impl Meta for BasicMeta { + fn name(&self) -> &'static str { + std::any::type_name::() + } +} + +macro_rules! basic_types { + ($($ty:ty),*) => { + $( + declare_field_access_fixed_size! 
{ + Meta = $ty, + Inflated = $ty, + Measure = $ty, + Builder = $ty, + Size = std::mem::size_of::<$ty>(), + Zero = 0, + + pub const fn meta() -> &'static dyn Meta { + &BasicMeta::<$ty> { _phantom: PhantomData } + } + + pub const fn extract(buf: &[u8; std::mem::size_of::<$ty>()]) -> Result<$ty, ParseError> { + Ok(<$ty>::from_be_bytes(*buf)) + } + + pub fn copy_to_buf(buf: &mut BufWriter, value: &$ty) { + buf.write(&<$ty>::to_be_bytes(*value)); + } + + pub const fn constant(value: usize) -> $ty { + value as _ + } + } + )* + }; +} + +basic_types!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128); diff --git a/rust/db_proto/src/field_access.rs b/rust/db_proto/src/field_access.rs new file mode 100644 index 00000000000..fc4f96d950e --- /dev/null +++ b/rust/db_proto/src/field_access.rs @@ -0,0 +1,569 @@ +use crate::{BufWriter, Enliven, Meta, ParseError}; + +/// As Rust does not currently support const in traits, we use this struct to +/// provide the const methods. It requires more awkward code, so we make use of +/// macros to generate the code. +/// +/// Note that another consequence is that we have to declare this struct twice: +/// once for this crate, and again when someone tries to instantiate a protocol. +/// The reason for this is that we cannot add additional `impl`s for this `FieldAccess` +/// outside of this crate. Instead, we use a macro to "copy" the existing `impl`s from +/// this crate to the newtype. +pub struct FieldAccess { + _phantom_data: std::marker::PhantomData, +} + +/// Delegates to a concrete [`FieldAccess`] but as a non-const trait. This is +/// used for performing extraction in iterators. 
+pub trait FieldAccessArray: Enliven { + const META: &'static dyn Meta; + fn size_of_field_at(buf: &[u8]) -> Result; + fn extract(buf: &[u8]) -> Result<::WithLifetime<'_>, ParseError>; + fn copy_to_buf(buf: &mut BufWriter, value: &Self::ForBuilder<'_>); +} + +/// A trait for types which are fixed-size, used to provide a `get` implementation +/// in arrays and iterators. +pub trait FixedSize: Enliven { + const SIZE: usize; + /// Extract this type from the given buffer, assuming that enough bytes are available. + fn extract_infallible(buf: &[u8]) -> ::WithLifetime<'_>; +} + +/// Declares a field access for a given type which is variably-sized. +#[macro_export] +#[doc(hidden)] +macro_rules! declare_field_access { + ( + Meta = $meta:ty, + Inflated = $inflated:ty, + Measure = $measured:ty, + Builder = $builder:ty, + + pub const fn meta() -> &'static dyn Meta + $meta_body:block + + pub const fn size_of_field_at($size_of_arg0:ident : &[u8]) -> Result + $size_of:block + + pub const fn extract($extract_arg0:ident : &[u8]) -> Result<$extract_ret:ty, ParseError> + $extract:block + + pub const fn measure($measure_arg0:ident : &$measure_param:ty) -> usize + $measure:block + + pub fn copy_to_buf($copy_arg0:ident : &mut BufWriter, $copy_arg1:ident : &$value_param:ty) + $copy:block + + pub const fn constant($constant_arg0:ident : usize) -> $constant_ret:ty + $constant:block + ) => { + impl Enliven for $meta { + type WithLifetime<'a> = $inflated; + type ForMeasure<'a> = $measured; + type ForBuilder<'a> = $builder; + } + + impl FieldAccess<$meta> { + #[inline(always)] + pub const fn meta() -> &'static dyn Meta { + $meta_body + } + #[inline(always)] + pub const fn size_of_field_at($size_of_arg0: &[u8]) -> Result { + $size_of + } + #[inline(always)] + pub const fn extract($extract_arg0: &[u8]) -> Result<$extract_ret, ParseError> { + $extract + } + #[inline(always)] + pub const fn measure($measure_arg0: &$measure_param) -> usize { + $measure + } + #[inline(always)] + pub fn 
copy_to_buf($copy_arg0: &mut BufWriter, $copy_arg1: &$value_param) { + $copy + } + #[inline(always)] + pub const fn constant($constant_arg0: usize) -> $constant_ret { + $constant + } + } + + $crate::field_access!($crate::FieldAccess, $meta); + $crate::array_access!(variable, $crate::FieldAccess, $meta); + }; +} + +/// Declares a field access for a given type which is fixed-size. Fixed-size +/// fields have simpler extraction logic, and support mapping to Rust arrays. +#[macro_export] +#[doc(hidden)] +macro_rules! declare_field_access_fixed_size { + ( + Meta = $meta:ty, + Inflated = $inflated:ty, + Measure = $measured:ty, + Builder = $builder:ty, + Size = $size:expr, + Zero = $zero:expr, + + pub const fn meta() -> &'static dyn Meta + $meta_body:block + + pub const fn extract($extract_arg0:ident : &$extract_type:ty) -> Result<$extract_ret:ty, ParseError> + $extract:block + + pub fn copy_to_buf($copy_arg0:ident : &mut BufWriter, $copy_arg1:ident : &$value_param:ty) + $copy:block + + pub const fn constant($constant_arg0:ident : usize) -> $constant_ret:ty + $constant:block + ) => { + impl Enliven for $meta { + type WithLifetime<'a> = $inflated; + type ForMeasure<'a> = $measured; + type ForBuilder<'a> = $builder; + } + + impl FieldAccess<$meta> { + #[inline(always)] + pub const fn meta() -> &'static dyn Meta { + $meta_body + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> Result { + if let Ok(_) = Self::extract(buf) { + Ok($size) + } else { + Err(ParseError::TooShort) + } + } + #[inline(always)] + pub const fn extract($extract_arg0: &[u8]) -> Result<$extract_ret, ParseError> { + if let Some(chunk) = $extract_arg0.first_chunk() { + FieldAccess::<$meta>::extract_exact(chunk) + } else { + Err(ParseError::TooShort) + } + } + #[inline(always)] + pub const fn extract_exact( + $extract_arg0: &[u8; $size], + ) -> Result<$extract_ret, ParseError> { + $extract + } + #[inline(always)] + pub const fn measure(_: &$measured) -> usize { + $size + } + 
#[inline(always)] + pub fn copy_to_buf($copy_arg0: &mut BufWriter, $copy_arg1: &$value_param) { + $copy + } + #[inline(always)] + pub const fn constant($constant_arg0: usize) -> $constant_ret { + $constant + } + } + + impl $crate::FixedSize for $meta { + const SIZE: usize = std::mem::size_of::<$inflated>(); + #[inline(always)] + fn extract_infallible(buf: &[u8]) -> $inflated { + FieldAccess::<$meta>::extract(buf).unwrap() + } + } + + impl Enliven for $crate::meta::FixedArray { + type WithLifetime<'a> = [$inflated; S]; + type ForMeasure<'a> = [$measured; S]; + type ForBuilder<'a> = [$builder; S]; + } + + #[allow(unused)] + impl FieldAccess<$crate::meta::FixedArray> { + #[inline(always)] + pub const fn meta() -> &'static dyn Meta { + &$crate::meta::FixedArray:: { + _phantom: PhantomData, + } + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> Result { + let size = $size * S; + if size > buf.len() { + Err($crate::ParseError::TooShort) + } else { + Ok(size) + } + } + #[inline(always)] + pub const fn measure(_: &[$measured; S]) -> usize { + ($size * (S)) + } + #[inline(always)] + pub const fn extract(mut buf: &[u8]) -> Result<[$inflated; S], $crate::ParseError> { + let mut out: [$inflated; S] = [const { $zero }; S]; + let mut i = 0; + loop { + if i == S { + break; + } + (out[i], buf) = if let Some((bytes, rest)) = buf.split_first_chunk() { + match FieldAccess::<$meta>::extract_exact(bytes) { + Ok(value) => (value, rest), + Err(e) => return Err(e), + } + } else { + return Err($crate::ParseError::TooShort); + }; + i += 1; + } + Ok(out) + } + #[inline(always)] + pub fn copy_to_buf(mut buf: &mut BufWriter, value: &[$builder; S]) { + if !buf.test(std::mem::size_of::<$builder>() * S) { + return; + } + for n in value { + FieldAccess::<$meta>::copy_to_buf(buf, n); + } + } + } + + impl FieldAccessArray for $crate::meta::FixedArray { + const META: &'static dyn Meta = FieldAccess::<$meta>::meta(); + #[inline(always)] + fn size_of_field_at(buf: &[u8]) -> Result { 
+ // TODO: needs to verify the values as well + FieldAccess::<$meta>::size_of_field_at(buf).map(|size| size * S) + } + #[inline(always)] + fn extract(mut buf: &[u8]) -> Result<[$inflated; S], ParseError> { + let mut out = [$zero; S]; + for i in 0..S { + (out[i], buf) = if let Some((bytes, rest)) = buf.split_first_chunk() { + (FieldAccess::<$meta>::extract_exact(bytes)?, rest) + } else { + return Err(ParseError::TooShort); + }; + } + Ok(out) + } + #[inline(always)] + fn copy_to_buf(buf: &mut BufWriter, value: &[$builder; S]) { + for n in value { + FieldAccess::<$meta>::copy_to_buf(buf, n); + } + } + } + + $crate::field_access!($crate::FieldAccess, $meta); + $crate::array_access!(fixed, $crate::FieldAccess, $meta); + }; +} + +/// Delegate to the concrete [`FieldAccess`] for each type we want to extract. +#[macro_export] +#[doc(hidden)] +macro_rules! field_access { + ($acc:ident :: FieldAccess, $ty:ty) => { + impl $crate::FieldAccessArray for $ty { + const META: &'static dyn $crate::Meta = $acc::FieldAccess::<$ty>::meta(); + #[inline(always)] + fn size_of_field_at(buf: &[u8]) -> Result { + $acc::FieldAccess::<$ty>::size_of_field_at(buf) + } + #[inline(always)] + fn extract( + buf: &[u8], + ) -> Result<::WithLifetime<'_>, $crate::ParseError> { + $acc::FieldAccess::<$ty>::extract(buf) + } + #[inline(always)] + fn copy_to_buf( + buf: &mut $crate::BufWriter, + value: &<$ty as $crate::Enliven>::ForBuilder<'_>, + ) { + $acc::FieldAccess::<$ty>::copy_to_buf(buf, value) + } + } + }; +} + +/// Define array accesses for inflated, strongly-typed arrays of both +/// zero-terminated and length-delimited types. +#[macro_export] +#[doc(hidden)] +macro_rules! 
array_access { + (fixed, $acc:ident :: FieldAccess, $ty:ty) => { + $crate::array_access!(fixed, $acc :: FieldAccess, $ty | u8 i16 u16 i32 u32); + }; + (variable, $acc:ident :: FieldAccess, $ty:ty) => { + $crate::array_access!(variable, $acc :: FieldAccess, $ty | u8 i16 u16 i32 u32); + }; + (fixed, $acc:ident :: FieldAccess, $ty:ty | $($len:ty)*) => { + $( + #[allow(unused)] + impl FieldAccess<$crate::meta::Array<$len, $ty>> { + pub const fn meta() -> &'static dyn Meta { + &$crate::meta::Array::<$len, $ty> { _phantom: PhantomData } + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> Result { + const N: usize = <$ty as $crate::FixedSize>::SIZE; + const L: usize = std::mem::size_of::<$len>(); + if let Some(len) = buf.first_chunk::() { + let len_value = <$len>::from_be_bytes(*len); + #[allow(unused_comparisons)] + if len_value < 0 { + return Err($crate::ParseError::InvalidData); + } + let mut byte_len = len_value as usize; + byte_len = match byte_len.checked_mul(N) { + Some(l) => l, + None => return Err($crate::ParseError::TooShort), + }; + byte_len = match byte_len.checked_add(L) { + Some(l) => l, + None => return Err($crate::ParseError::TooShort), + }; + if buf.len() < byte_len { + Err($crate::ParseError::TooShort) + } else { + Ok(byte_len) + } + } else { + Err($crate::ParseError::TooShort) + } + } + #[inline(always)] + pub const fn extract(mut buf: &[u8]) -> Result<$crate::Array<$len, $ty>, $crate::ParseError> { + const N: usize = <$ty as $crate::FixedSize>::SIZE; + const L: usize = std::mem::size_of::<$len>(); + if let Some((len, array)) = buf.split_first_chunk::() { + let len_value = <$len>::from_be_bytes(*len); + #[allow(unused_comparisons)] + if len_value < 0 { + return Err($crate::ParseError::InvalidData); + } + let mut byte_len = len_value as usize; + byte_len = match byte_len.checked_mul(N) { + Some(l) => l, + None => return Err($crate::ParseError::TooShort), + }; + byte_len = match byte_len.checked_add(L) { + Some(l) => l, + None => return 
Err($crate::ParseError::TooShort), + }; + if buf.len() < byte_len { + Err($crate::ParseError::TooShort) + } else { + Ok($crate::Array::new(array, len_value as u32)) + } + } else { + Err($crate::ParseError::TooShort) + } + } + #[inline(always)] + pub const fn measure<'a>(buffer: &'a[<$ty as $crate::Enliven>::ForMeasure<'a>]) -> usize { + buffer.len() * std::mem::size_of::<$ty>() + std::mem::size_of::<$len>() + } + #[inline(always)] + pub fn copy_to_buf<'a>(mut buf: &mut BufWriter, value: &'a[<$ty as $crate::Enliven>::ForBuilder<'a>]) { + let size: usize = std::mem::size_of::<$ty>() * value.len() + std::mem::size_of::<$len>(); + if !buf.test(size) { + return; + } + buf.write(&<$len>::to_be_bytes(value.len() as _)); + for n in value { + $acc::FieldAccess::<$ty>::copy_to_buf(buf, n); + } + } + #[inline(always)] + pub const fn constant(value: usize) -> $crate::Array<'static, $len, $ty> { + panic!("Constants unsupported for this data type") + } + } + )* + + #[allow(unused)] + impl $acc::FieldAccess<$crate::meta::ZTArray<$ty>> { + pub const fn meta() -> &'static dyn $crate::Meta { + &$crate::meta::ZTArray::<$ty> { _phantom: std::marker::PhantomData } + } + #[inline] + pub const fn size_of_field_at(mut buf: &[u8]) -> Result { + let mut size = 1; + loop { + if buf.is_empty() { + return Err($crate::ParseError::TooShort); + } + if buf[0] == 0 { + return Ok(size); + } + let elem_size = match $acc::FieldAccess::<$ty>::size_of_field_at(buf) { + Ok(n) => n, + Err(e) => return Err(e), + }; + buf = buf.split_at(elem_size).1; + size += elem_size; + } + } + #[inline(always)] + pub const fn extract(mut buf: &[u8]) -> Result<$crate::ZTArray<$ty>, $crate::ParseError> { + Ok($crate::ZTArray::new(buf)) + } + #[inline] + pub const fn measure<'a>(buffer: &'a[<$ty as $crate::Enliven>::ForMeasure<'a>]) -> usize { + let mut size = 1; + let mut index = 0; + loop { + if index + 1 > buffer.len() { + break; + } + let item = &buffer[index]; + size += $acc::FieldAccess::<$ty>::measure(item); + index 
+= 1; + } + size + } + #[inline(always)] + pub fn copy_to_buf(buf: &mut $crate::BufWriter, value: &[<$ty as $crate::Enliven>::ForBuilder<'_>]) { + for elem in value { + $acc::FieldAccess::<$ty>::copy_to_buf(buf, elem); + } + buf.write_u8(0); + } + #[inline(always)] + pub const fn constant(value: usize) -> $crate::ZTArray<'static, $ty> { + panic!("Constants unsupported for this data type") + } + } + }; + (variable, $acc:ident :: FieldAccess, $ty:ty | $($len:ty)*) => { + $( + #[allow(unused)] + impl $acc::FieldAccess<$crate::meta::Array<$len, $ty>> { + pub const fn meta() -> &'static dyn $crate::Meta { + &$crate::meta::Array::<$len, $ty> { _phantom: std::marker::PhantomData } + } + #[inline] + pub const fn size_of_field_at(mut buf: &[u8]) -> Result { + let mut size = std::mem::size_of::<$len>(); + let mut len = match $acc::FieldAccess::<$len>::extract(buf) { + Ok(n) => n, + Err(e) => return Err(e), + }; + #[allow(unused_comparisons)] + if len < 0 { + return Err($crate::ParseError::InvalidData); + } + buf = buf.split_at(size).1; + loop { + if len <= 0 { + break; + } + len -= 1; + let elem_size = match $acc::FieldAccess::<$ty>::size_of_field_at(buf) { + Ok(n) => n, + Err(e) => return Err(e), + }; + buf = buf.split_at(elem_size).1; + size += elem_size; + } + Ok(size) + } + #[inline(always)] + pub const fn extract(buf: &[u8]) -> Result<$crate::Array<'_, $len, $ty>, $crate::ParseError> { + match $acc::FieldAccess::<$len>::extract(buf) { + Ok(len) => Ok($crate::Array::new(buf.split_at(std::mem::size_of::<$len>()).1, len as u32)), + Err(e) => Err(e) + } + } + #[inline] + pub const fn measure<'a>(buffer: &'a[<$ty as $crate::Enliven>::ForMeasure<'a>]) -> usize { + let mut size = std::mem::size_of::<$len>(); + let mut index = 0; + loop { + if index + 1 > buffer.len() { + break; + } + let item = &buffer[index]; + size += $acc::FieldAccess::<$ty>::measure(item); + index += 1; + } + size + } + #[inline(always)] + pub fn copy_to_buf<'a>(buf: &mut $crate::BufWriter, value: &'a[<$ty 
as $crate::Enliven>::ForBuilder<'a>]) { + buf.write(&<$len>::to_be_bytes(value.len() as _)); + for elem in value { + $acc::FieldAccess::<$ty>::copy_to_buf(buf, elem); + } + } + #[inline(always)] + pub const fn constant(value: usize) -> $crate::Array<'static, $len, $ty> { + panic!("Constants unsupported for this data type") + } + } + )* + + #[allow(unused)] + impl $acc::FieldAccess<$crate::meta::ZTArray<$ty>> { + pub const fn meta() -> &'static dyn $crate::Meta { + &$crate::meta::ZTArray::<$ty> { _phantom: std::marker::PhantomData } + } + #[inline] + pub const fn size_of_field_at(mut buf: &[u8]) -> Result { + let mut size = 1; + loop { + if buf.is_empty() { + return Err($crate::ParseError::TooShort); + } + if buf[0] == 0 { + return Ok(size); + } + let elem_size = match $acc::FieldAccess::<$ty>::size_of_field_at(buf) { + Ok(n) => n, + Err(e) => return Err(e), + }; + buf = buf.split_at(elem_size).1; + size += elem_size; + } + } + #[inline(always)] + pub const fn extract(mut buf: &[u8]) -> Result<$crate::ZTArray<$ty>, $crate::ParseError> { + Ok($crate::ZTArray::new(buf)) + } + #[inline] + pub const fn measure<'a>(buffer: &'a[<$ty as $crate::Enliven>::ForMeasure<'a>]) -> usize { + let mut size = 1; + let mut index = 0; + loop { + if index + 1 > buffer.len() { + break; + } + let item = &buffer[index]; + size += $acc::FieldAccess::<$ty>::measure(item); + index += 1; + } + size + } + #[inline(always)] + pub fn copy_to_buf(buf: &mut $crate::BufWriter, value: &[<$ty as $crate::Enliven>::ForBuilder<'_>]) { + for elem in value { + $acc::FieldAccess::<$ty>::copy_to_buf(buf, elem); + } + buf.write_u8(0); + } + #[inline(always)] + pub const fn constant(value: usize) -> $crate::ZTArray<'static, $ty> { + panic!("Constants unsupported for this data type") + } + } + }; +} diff --git a/rust/pgrust/src/protocol/gen.rs b/rust/db_proto/src/gen.rs similarity index 71% rename from rust/pgrust/src/protocol/gen.rs rename to rust/db_proto/src/gen.rs index 19dca193326..15e4e0cf6c6 100644 --- 
a/rust/pgrust/src/protocol/gen.rs +++ b/rust/db_proto/src/gen.rs @@ -10,7 +10,7 @@ /// level of the macro adds its own layer of processing and metadata /// accumulation, eventually leading to the final output. /// -/// The `struct_elaborate!` macro is a tool designed to perform an initial +/// The `$crate::struct_elaborate!` macro is a tool designed to perform an initial /// parsing pass on a Rust `struct`, enriching it with metadata to facilitate /// further macro processing. It begins by extracting and analyzing the fields /// of the `struct`, capturing associated metadata such as attributes and types. @@ -35,6 +35,8 @@ /// it reconstructs an enriched `struct`-like data blob using the accumulated /// metadata. It then passes this enriched `struct` to the `next` macro for /// further processing. +#[doc(hidden)] +#[macro_export] macro_rules! struct_elaborate { ( $next:ident $( ($($next_args:tt)*) )? => @@ -50,7 +52,7 @@ macro_rules! struct_elaborate { ) => { // paste! is necessary here because it allows us to re-interpret a "ty" // as an explicit type pattern below. - struct_elaborate!(__builder_type__ + $crate::struct_elaborate!(__builder_type__ // Pass down a "fixed offset" flag that indicates whether the // current field is at a fixed offset. This gets reset to // `no_fixed_offset` when we hit a variable-sized field. @@ -74,66 +76,66 @@ macro_rules! 
struct_elaborate { // End of push-down automation - jumps to `__finalize__` (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields() accum($($faccum:tt)*) original($($original:tt)*)) => { - struct_elaborate!(__finalize__ accum($($faccum)*) original($($original)*)); + $crate::struct_elaborate!(__finalize__ accum($($faccum)*) original($($original)*)); }; // Skip __builder_value__ for 'len' (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(len)(len), value(), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($crate::protocol::meta::Length), size(fixed=fixed), value(auto=auto), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($crate::meta::Length), size(fixed=fixed), value(auto=auto), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(len)(len), value($($value:tt)+), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($crate::protocol::meta::Length), size(fixed=fixed), value(value=($($value)*)), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($crate::meta::Length), size(fixed=fixed), value(value=($($value)*)), $($rest)*] $($frest)*) $($srest)*); }; // Pattern match on known fixed-sized types and mark them as `size(fixed=fixed)` - (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type([u8; $len:literal])($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + (__builder_type__ fixed($fixed:ident 
$fixed_expr:expr) fields([type([$elem:ty; $len:literal])($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($crate::meta::FixedArray<$len, $elem>), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(u8)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr)fields([type(i16)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(i32)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(u32)($ty:ty), $($rest:tt)*] 
$($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(u64)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(Uuid)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; // Fallback for other types - variable sized (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type($ty:ty)($ty2:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>no_fixed_offset $fixed_expr=>(0)) fields([type($ty), size(variable=variable), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>no_fixed_offset $fixed_expr=>(0)) fields([type($ty), size(variable=variable), $($rest)*] $($frest)*) $($srest)*); }; // Next, mark the 
presence or absence of a value (__builder_value__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ type($ty:ty), size($($size:tt)*), value(), $($rest:tt)* ] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value(no_value=no_value), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value(no_value=no_value), $($rest)*] $($frest)*) $($srest)*); }; (__builder_value__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ type($ty:ty), size($($size:tt)*), value($($value:tt)+), $($rest:tt)* ] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value(value=($($value)*)), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value(value=($($value)*)), $($rest)*] $($frest)*) $($srest)*); }; // Next, handle missing docs (__builder_docs__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ type($ty:ty), size($($size:tt)*), value($($value:tt)*), docs(), name($field:ident), $($rest:tt)* ] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value($($value)*), docs(concat!("`", stringify!($field), "` field.")), name($field), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value($($value)*), docs(concat!("`", stringify!($field), "` field.")), name($field), $($rest)*] $($frest)*) $($srest)*); }; 
(__builder_docs__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ type($ty:ty), size($($size:tt)*), value($($value:tt)*), docs($($fdoc:literal)+), $($rest:tt)* ] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value($($value)*), docs(concat!($($fdoc)+)), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value($($value)*), docs(concat!($($fdoc)+)), $($rest)*] $($frest)*) $($srest)*); }; @@ -141,7 +143,7 @@ macro_rules! struct_elaborate { (__builder__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ type($ty:ty), size($($size:tt)*), value($($value:tt)*), docs($fdoc:expr), name($field:ident), $($rest:tt)* ] $($frest:tt)*) accum($($faccum:tt)*) original($($original:tt)*)) => { - struct_elaborate!(__builder_type__ fixed($fixed_new $fixed_expr_new) fields($($frest)*) accum( + $crate::struct_elaborate!(__builder_type__ fixed($fixed_new $fixed_expr_new) fields($($frest)*) accum( $($faccum)* { name($field), @@ -169,19 +171,50 @@ macro_rules! struct_elaborate { } } -macro_rules! protocol { +/// Generates a protocol definition from a Rust-like DSL. +/// +/// ``` +/// struct Foo { +/// bar: u8, +/// baz: u16, +/// } +/// ``` +#[doc(hidden)] +#[macro_export] +macro_rules! __protocol { ($( $( #[ $sdoc:meta ] )* struct $name:ident $(: $super:ident)? { $($struct:tt)+ } )+) => { + mod access { + #![allow(unused)] + + /// This struct is specialized for each type we want to extract data from. We + /// have to do it this way to work around Rust's lack of const specialization. 
+ pub struct FieldAccess { + _phantom_data: std::marker::PhantomData, + } + + $crate::field_access_copy!{basic $crate::FieldAccess, self::FieldAccess, + i8, u8, i16, u16, i32, u32, i64, u64, i128, u128, + $crate::meta::Uuid + } + $crate::field_access_copy!{$crate::FieldAccess, self::FieldAccess, + $crate::meta::ZTString, + $crate::meta::LString, + $crate::meta::Rest, + $crate::meta::Encoded, + $crate::meta::Length + } + } + $( - paste::paste!( + $crate::paste!( #[allow(unused_imports)] pub(crate) mod [<__ $name:lower>] { + use $crate::{meta::*, protocol_builder}; use super::meta::*; - use $crate::protocol::meta::*; - use $crate::protocol::gen::*; - struct_elaborate!(protocol_builder(__struct__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); - struct_elaborate!(protocol_builder(__meta__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); - struct_elaborate!(protocol_builder(__measure__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); - struct_elaborate!(protocol_builder(__builder__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); + $crate::struct_elaborate!(protocol_builder(__struct__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); + $crate::struct_elaborate!(protocol_builder(__meta__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); + $crate::struct_elaborate!(protocol_builder(__measure__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); + $crate::struct_elaborate!(protocol_builder(__builder__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); } ); )+ @@ -189,7 +222,7 @@ macro_rules! protocol { pub mod data { #![allow(unused_imports)] $( - paste::paste!( + $crate::paste!( pub use super::[<__ $name:lower>]::$name; ); )+ @@ -197,7 +230,7 @@ macro_rules! protocol { pub mod meta { #![allow(unused_imports)] $( - paste::paste!( + $crate::paste!( pub use super::[<__ $name:lower>]::[<$name Meta>] as $name; ); )+ @@ -205,7 +238,7 @@ macro_rules! 
protocol { /// A slice containing the metadata references for all structs in /// this definition. #[allow(unused)] - pub const ALL: &'static [&'static dyn $crate::protocol::Meta] = &[ + pub const ALL: &'static [&'static dyn $crate::Meta] = &[ $( &$name {} ),* @@ -214,7 +247,7 @@ macro_rules! protocol { pub mod builder { #![allow(unused_imports)] $( - paste::paste!( + $crate::paste!( pub use super::[<__ $name:lower>]::[<$name Builder>] as $name; ); )+ @@ -222,7 +255,7 @@ macro_rules! protocol { pub mod measure { #![allow(unused_imports)] $( - paste::paste!( + $crate::paste!( pub use super::[<__ $name:lower>]::[<$name Measure>] as $name; ); )+ @@ -230,6 +263,11 @@ macro_rules! protocol { }; } +#[doc(inline)] +pub use __protocol as protocol; + +#[macro_export] +#[doc(hidden)] macro_rules! r#if { (__is_empty__ [] {$($true:tt)*} else {$($false:tt)*}) => { $($true)* @@ -244,6 +282,8 @@ macro_rules! r#if { }; } +#[doc(hidden)] +#[macro_export] macro_rules! protocol_builder { (__struct__, struct $name:ident { super($($super:ident)?), @@ -258,7 +298,7 @@ macro_rules! protocol_builder { $($rest:tt)* },)*), }) => { - paste::paste!( + $crate::paste!( /// Our struct we are building. type S<'a> = $name<'a>; /// The meta-struct for the struct we are building. @@ -276,6 +316,7 @@ macro_rules! protocol_builder { $( " (value = `", stringify!($value), "`)", )? "\n\n" )* )] + #[derive(Copy, Clone)] pub struct $name<'a> { /// Our zero-copy buffer. #[doc(hidden)] @@ -320,23 +361,23 @@ macro_rules! protocol_builder { $( $( - let Ok(val) = $crate::protocol::FieldAccess::<$type>::extract(buf.split_at(offset).1) else { + let Ok(val) = super::access::FieldAccess::<$type>::extract(buf.split_at(offset).1) else { return false; }; if val as usize != $value as usize { return false; } )? 
- offset += std::mem::size_of::<$type>(); + offset += std::mem::size_of::<<$type as $crate::Enliven>::ForBuilder<'static>>(); )* true } $( - pub const fn can_cast(parent: &<$super as $crate::protocol::Enliven>::WithLifetime<'a>) -> bool { + pub const fn can_cast(parent: &::WithLifetime<'a>) -> bool { Self::is_buffer(parent.__buf) } - pub const fn try_new(parent: &<$super as $crate::protocol::Enliven>::WithLifetime<'a>) -> Option { + pub const fn try_new(parent: &::WithLifetime<'a>) -> Option { if Self::can_cast(parent) { // TODO let Ok(value) = Self::new(parent.__buf) else { @@ -351,13 +392,13 @@ macro_rules! protocol_builder { /// Creates a new instance of this struct from a given buffer. #[inline] - pub const fn new(mut buf: &'a [u8]) -> Result { + pub const fn new(mut buf: &'a [u8]) -> Result { let mut __field_offsets = [0; Meta::FIELD_COUNT + 1]; let mut offset = 0; let mut index = 0; $( __field_offsets[index] = offset; - offset += match $crate::protocol::FieldAccess::<$type>::size_of_field_at(buf.split_at(offset).1) { + offset += match super::access::FieldAccess::<$type>::size_of_field_at(buf.split_at(offset).1) { Ok(n) => n, Err(e) => return Err(e), }; @@ -371,7 +412,7 @@ macro_rules! protocol_builder { }) } - pub fn to_vec(&self) -> Vec { + pub fn to_vec(self) -> Vec { self.__buf.to_vec() } @@ -379,14 +420,14 @@ macro_rules! 
protocol_builder { #[doc = $fdoc] #[allow(unused)] #[inline] - pub const fn $field<'s>(&'s self) -> <$type as $crate::protocol::Enliven>::WithLifetime<'a> where 's : 'a { + pub const fn $field<'s>(&'s self) -> <$type as $crate::Enliven>::WithLifetime<'a> where 's : 'a { // Perform a const buffer extraction operation let offset1 = self.__field_offsets[F::$field as usize]; let offset2 = self.__field_offsets[F::$field as usize + 1]; let (_, buf) = self.__buf.split_at(offset1); let (buf, _) = buf.split_at(offset2 - offset1); // This will not panic: we've confirmed the validity of the buffer when sizing - let Ok(value) = $crate::protocol::FieldAccess::<$type>::extract(buf) else { + let Ok(value) = super::access::FieldAccess::<$type>::extract(buf) else { panic!(); }; value @@ -409,7 +450,7 @@ macro_rules! protocol_builder { $($rest:tt)* },)*), }) => { - paste::paste!( + $crate::paste!( $( #[$sdoc] )? #[allow(unused)] #[derive(Debug, Default)] @@ -429,25 +470,25 @@ macro_rules! protocol_builder { #[allow(unused)] impl Meta { pub const FIELD_COUNT: usize = [$(stringify!($field)),*].len(); - $($(pub const [<$field:upper _VALUE>]: $type = $crate::protocol::FieldAccess::<$type>::constant($value as usize);)?)* + $($(pub const [<$field:upper _VALUE>]: <$type as $crate::Enliven>::WithLifetime<'static> = super::access::FieldAccess::<$type>::constant($value as usize);)?)* } - impl $crate::protocol::Meta for Meta { + impl $crate::Meta for Meta { fn name(&self) -> &'static str { stringify!($name) } - fn relations(&self) -> &'static [($crate::protocol::MetaRelation, &'static dyn $crate::protocol::Meta)] { - r#if!(__is_empty__ [$($super)?] { - const RELATIONS: &'static [($crate::protocol::MetaRelation, &'static dyn $crate::protocol::Meta)] = &[ + fn relations(&self) -> &'static [($crate::MetaRelation, &'static dyn $crate::Meta)] { + $crate::r#if!(__is_empty__ [$($super)?] 
{ + const RELATIONS: &'static [($crate::MetaRelation, &'static dyn $crate::Meta)] = &[ $( - ($crate::protocol::MetaRelation::Field(stringify!($field)), $crate::protocol::FieldAccess::<$type>::meta()) + ($crate::MetaRelation::Field(stringify!($field)), super::access::FieldAccess::<$type>::meta()) ),* ]; } else { - const RELATIONS: &'static [($crate::protocol::MetaRelation, &'static dyn $crate::protocol::Meta)] = &[ - ($crate::protocol::MetaRelation::Parent, $crate::protocol::FieldAccess::<$($super)?>::meta()), + const RELATIONS: &'static [($crate::MetaRelation, &'static dyn $crate::Meta)] = &[ + ($crate::MetaRelation::Parent, super::access::FieldAccess::::meta()), $( - ($crate::protocol::MetaRelation::Field(stringify!($field)), $crate::protocol::FieldAccess::<$type>::meta()) + ($crate::MetaRelation::Field(stringify!($field)), super::access::FieldAccess::<$type>::meta()) ),* ]; }); @@ -459,9 +500,9 @@ macro_rules! protocol_builder { protocol_builder!(__meta__, $fixed($fixed_expr) $field $type); )* - impl $crate::protocol::StructMeta for Meta { + impl $crate::StructMeta for Meta { type Struct<'a> = S<'a>; - fn new(buf: &[u8]) -> Result, $crate::protocol::ParseError> { + fn new(buf: &[u8]) -> Result, $crate::ParseError> { S::new(buf) } fn to_vec(s: &Self::Struct<'_>) -> Vec { @@ -469,27 +510,27 @@ macro_rules! 
protocol_builder { } } - impl $crate::protocol::Enliven for Meta { + impl $crate::Enliven for Meta { type WithLifetime<'a> = S<'a>; type ForMeasure<'a> = M<'a>; type ForBuilder<'a> = B<'a>; } #[allow(unused)] - impl $crate::protocol::FieldAccess { + impl super::access::FieldAccess { #[inline(always)] pub const fn name() -> &'static str { stringify!($name) } #[inline(always)] - pub const fn meta() -> &'static dyn $crate::protocol::Meta { + pub const fn meta() -> &'static dyn $crate::Meta { &Meta {} } #[inline] - pub const fn size_of_field_at(buf: &[u8]) -> Result { + pub const fn size_of_field_at(buf: &[u8]) -> Result { let mut offset = 0; $( - offset += match $crate::protocol::FieldAccess::<$type>::size_of_field_at(buf.split_at(offset).1) { + offset += match super::access::FieldAccess::<$type>::size_of_field_at(buf.split_at(offset).1) { Ok(n) => n, Err(e) => return Err(e), }; @@ -497,7 +538,7 @@ macro_rules! protocol_builder { Ok(offset) } #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result<$name<'_>, $crate::protocol::ParseError> { + pub const fn extract(buf: &[u8]) -> Result<$name<'_>, $crate::ParseError> { $name::new(buf) } #[inline(always)] @@ -505,21 +546,18 @@ macro_rules! 
protocol_builder { measure.measure() } #[inline(always)] - pub fn copy_to_buf(buf: &mut $crate::protocol::writer::BufWriter, builder: &B) { - builder.copy_to_buf(buf) - } - #[inline(always)] - pub fn copy_to_buf_ref(buf: &mut $crate::protocol::writer::BufWriter, builder: &B) { + pub fn copy_to_buf(buf: &mut $crate::BufWriter, builder: &B) { builder.copy_to_buf(buf) } } - $crate::protocol::field_access!{[<$name Meta>]} - $crate::protocol::arrays::array_access!{[<$name Meta>]} + use super::access::FieldAccess as FieldAccess; + $crate::field_access!{self::FieldAccess, [<$name Meta>]} + $crate::array_access!{variable, self::FieldAccess, [<$name Meta>]} ); }; - (__meta__, fixed_offset($fixed_expr:expr) $field:ident $crate::protocol::meta::Length) => { - impl $crate::protocol::StructLength for Meta { + (__meta__, fixed_offset($fixed_expr:expr) $field:ident $crate::meta::Length) => { + impl $crate::StructLength for Meta { fn length_field_of(of: &Self::Struct<'_>) -> usize { of.$field() } @@ -528,7 +566,7 @@ macro_rules! protocol_builder { } } }; - (__meta__, $fixed:ident($fixed_expr:expr) $field:ident $crate::protocol::meta::Rest) => { + (__meta__, $fixed:ident($fixed_expr:expr) $field:ident $crate::meta::Rest) => { }; (__meta__, $fixed:ident($fixed_expr:expr) $field:ident $any:ty) => { @@ -546,8 +584,8 @@ macro_rules! protocol_builder { $($rest:tt)* },)*), }) => { - paste::paste!( - r#if!(__is_empty__ [$($($variable_marker)?)*] { + $crate::paste!( + $crate::r#if!(__is_empty__ [$($($variable_marker)?)*] { $( #[$sdoc] )? // No variable-sized fields #[derive(Default, Eq, PartialEq)] @@ -563,7 +601,7 @@ macro_rules! protocol_builder { // pattern. $($( #[doc = $fdoc] - pub $field: r#if!(__has__ [$variable_marker] {<$type as $crate::protocol::Enliven>::ForMeasure<'a>}), + pub $field: $crate::r#if!(__has__ [$variable_marker] {<$type as $crate::Enliven>::ForMeasure<'a>}), )?)* } }); @@ -572,8 +610,8 @@ macro_rules! 
protocol_builder { pub const fn measure(&self) -> usize { let mut size = 0; $( - r#if!(__has__ [$($variable_marker)?] { size += $crate::protocol::FieldAccess::<$type>::measure(&self.$field); }); - r#if!(__has__ [$($fixed_marker)?] { size += std::mem::size_of::<$type>(); }); + $crate::r#if!(__has__ [$($variable_marker)?] { size += super::access::FieldAccess::<$type>::measure(&self.$field); }); + $crate::r#if!(__has__ [$($fixed_marker)?] { size += std::mem::size_of::<<$type as $crate::Enliven>::ForBuilder<'static>>(); }); )* size } @@ -593,8 +631,8 @@ macro_rules! protocol_builder { $($rest:tt)* },)*), }) => { - paste::paste!( - r#if!(__is_empty__ [$($($no_value)?)*] { + $crate::paste!( + $crate::r#if!(__is_empty__ [$($($no_value)?)*] { $( #[$sdoc] )? // No unfixed-value fields #[derive(::derive_more::Debug, Default, Eq, PartialEq)] @@ -611,30 +649,30 @@ macro_rules! protocol_builder { // somehow use $no_value in the remainder of the pattern. $($( #[doc = $fdoc] - pub $field: r#if!(__has__ [$no_value] {<$type as $crate::protocol::Enliven>::ForBuilder<'a>}), + pub $field: $crate::r#if!(__has__ [$no_value] {<$type as $crate::Enliven>::ForBuilder<'a>}), )?)* } }); impl B<'_> { #[allow(unused)] - pub fn copy_to_buf(&self, buf: &mut $crate::protocol::writer::BufWriter) { + pub fn copy_to_buf(&self, buf: &mut $crate::BufWriter) { $( - r#if!(__is_empty__ [$($value)?] { - r#if!(__is_empty__ [$($auto)?] { - $crate::protocol::FieldAccess::<$type>::copy_to_buf(buf, self.$field); + $crate::r#if!(__is_empty__ [$($value)?] { + $crate::r#if!(__is_empty__ [$($auto)?] { + <$type as $crate::FieldAccessArray>::copy_to_buf(buf, &self.$field); } else { let auto_offset = buf.size(); - $crate::protocol::FieldAccess::<$type>::copy_to_buf(buf, 0); + <$type as $crate::FieldAccessArray>::copy_to_buf(buf, &0); }); } else { - $crate::protocol::FieldAccess::<$type>::copy_to_buf(buf, $($value)? as usize as _); + <$type as $crate::FieldAccessArray>::copy_to_buf(buf, &($($value)? 
as usize as _)); }); )* $( - r#if!(__has__ [$($auto)?] { - $crate::protocol::FieldAccess::::copy_to_buf_rewind(buf, auto_offset, buf.size() - auto_offset); + $crate::r#if!(__has__ [$($auto)?] { + $crate::FieldAccess::<$crate::meta::Length>::copy_to_buf_rewind(buf, auto_offset, buf.size() - auto_offset); }); )* @@ -645,7 +683,7 @@ macro_rules! protocol_builder { #[allow(unused)] pub fn to_vec(&self) -> Vec { let mut vec = Vec::with_capacity(256); - let mut buf = $crate::protocol::writer::BufWriter::new(&mut vec); + let mut buf = $crate::BufWriter::new(&mut vec); self.copy_to_buf(&mut buf); match buf.finish() { Ok(size) => { @@ -654,7 +692,7 @@ macro_rules! protocol_builder { }, Err(size) => { vec.resize(size, 0); - let mut buf = $crate::protocol::writer::BufWriter::new(&mut vec); + let mut buf = $crate::BufWriter::new(&mut vec); self.copy_to_buf(&mut buf); // Will not fail this second time let size = buf.finish().unwrap(); @@ -668,14 +706,12 @@ macro_rules! protocol_builder { }; } -pub(crate) use {protocol, protocol_builder, r#if, struct_elaborate}; - #[cfg(test)] mod tests { use pretty_assertions::assert_eq; mod fixed_only { - protocol!( + crate::protocol!( struct FixedOnly { a: u8, } @@ -683,20 +719,20 @@ mod tests { } mod fixed_only_value { - protocol!(struct FixedOnlyValue { + crate::protocol!(struct FixedOnlyValue { a: u8 = 1, }); } mod mixed { - protocol!(struct Mixed { + crate::protocol!(struct Mixed { a: u8 = 1, s: ZTString, }); } mod docs { - protocol!( + crate::protocol!( /// Docs struct Docs { /// Docs @@ -708,7 +744,7 @@ mod tests { } mod length { - protocol!( + crate::protocol!( struct WithLength { a: u8, l: len, @@ -717,7 +753,7 @@ mod tests { } mod array { - protocol!( + crate::protocol!( struct StaticArray { a: u8, l: [u8; 4], @@ -726,7 +762,7 @@ mod tests { } mod string { - protocol!( + crate::protocol!( struct HasLString { s: LString, } @@ -735,7 +771,7 @@ mod tests { macro_rules! 
assert_stringify { (($($struct:tt)*), ($($expected:tt)*)) => { - struct_elaborate!(assert_stringify(__internal__ ($($expected)*)) => $($struct)*); + $crate::struct_elaborate!(assert_stringify(__internal__ ($($expected)*)) => $($struct)*); }; (__internal__ ($($expected:tt)*), $($struct:tt)*) => { // We don't want whitespace to impact this comparison @@ -791,7 +827,7 @@ mod tests { fixed(fixed_offset = fixed_offset, (0)), }, { - name(l), type (crate::protocol::meta::Length), size(fixed = fixed), + name(l), type (crate::meta::Length), size(fixed = fixed), value(auto = auto), docs(concat!("`", stringify! (l), "` field.")), fixed(fixed_offset = fixed_offset, ((0) + std::mem::size_of::())), }, @@ -807,7 +843,7 @@ mod tests { fixed(no_fixed_offset = no_fixed_offset, (0)), }, { - name(d), type ([u8; 4]), size(fixed = fixed), + name(d), type (crate::meta::FixedArray<4, u8>), size(fixed = fixed), value(no_value = no_value), docs(concat!("`", stringify! (d), "` field.")), fixed(no_fixed_offset = no_fixed_offset, ((0) + std::mem::size_of::())), @@ -817,7 +853,8 @@ mod tests { value(no_value = no_value), docs(concat!("`", stringify! (e), "` field.")), fixed(no_fixed_offset = no_fixed_offset, - (((0) + std::mem::size_of::()) + std::mem::size_of::<[u8; 4]>())), + (((0) + std::mem::size_of::()) + + std::mem::size_of::<[u8; 4]>())), }, ), })); diff --git a/rust/db_proto/src/lib.rs b/rust/db_proto/src/lib.rs new file mode 100644 index 00000000000..537ed319568 --- /dev/null +++ b/rust/db_proto/src/lib.rs @@ -0,0 +1,218 @@ +mod arrays; +mod buffer; +mod datatypes; +mod field_access; +mod gen; +mod message_group; +mod writer; + +#[doc(hidden)] +pub mod test_protocol; + +/// Metatypes for the protocol and related arrays/strings. 
+pub mod meta { + pub use super::arrays::meta::*; + pub use super::datatypes::meta::*; +} + +#[allow(unused)] +pub use arrays::{Array, ArrayIter, ZTArray, ZTArrayIter}; +pub use buffer::StructBuffer; +#[allow(unused)] +pub use datatypes::{Encoded, LString, Length, Rest, Uuid, ZTString}; +pub use field_access::{FieldAccess, FieldAccessArray, FixedSize}; +pub use writer::BufWriter; + +#[doc(inline)] +pub use gen::protocol; +#[doc(inline)] +pub use message_group::{match_message, message_group}; + +/// Re-export for the `protocol!` macro. +#[doc(hidden)] +pub use paste::paste; + +#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq, Eq)] +pub enum ParseError { + #[error("Buffer is too short")] + TooShort, + #[error("Invalid data")] + InvalidData, +} + +/// Implemented for all structs. +pub trait StructMeta { + type Struct<'a>: std::fmt::Debug; + fn new(buf: &[u8]) -> Result, ParseError>; + fn to_vec(s: &Self::Struct<'_>) -> Vec; +} + +/// Implemented for all generated structs that have a [`meta::Length`] field at a fixed offset. +pub trait StructLength: StructMeta { + fn length_field_of(of: &Self::Struct<'_>) -> usize; + fn length_field_offset() -> usize; + fn length_of_buf(buf: &[u8]) -> Option { + if buf.len() < Self::length_field_offset() + std::mem::size_of::() { + None + } else { + let len = FieldAccess::::extract( + &buf[Self::length_field_offset() + ..Self::length_field_offset() + std::mem::size_of::()], + ) + .ok()?; + Some(Self::length_field_offset() + len) + } + } +} + +/// For a given metaclass, returns the inflated type, a measurement type and a +/// builder type. +/// +/// Types that don't include a lifetime can use the same type for the meta type +/// and the `WithLifetime` type. 
+pub trait Enliven { + type WithLifetime<'a>; + type ForMeasure<'a>: 'a; + type ForBuilder<'a>: 'a; +} + +#[derive(Debug, Eq, PartialEq)] +pub enum MetaRelation { + Parent, + Length, + Item, + Field(&'static str), +} + +pub trait Meta { + fn name(&self) -> &'static str { + std::any::type_name::() + } + fn relations(&self) -> &'static [(MetaRelation, &'static dyn Meta)] { + &[] + } + fn fixed_length(&self) -> Option { + None + } + fn field(&self, name: &'static str) -> Option<&'static dyn Meta> { + for (relation, meta) in self.relations() { + if relation == &MetaRelation::Field(name) { + return Some(*meta); + } + } + None + } + fn parent(&self) -> Option<&'static dyn Meta> { + for (relation, meta) in self.relations() { + if relation == &MetaRelation::Parent { + return Some(*meta); + } + } + None + } +} + +impl PartialEq for dyn Meta { + fn eq(&self, other: &T) -> bool { + other.name() == self.name() + } +} + +impl std::fmt::Debug for dyn Meta { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut s = f.debug_struct(self.name()); + if let Some(length) = self.fixed_length() { + s.field("Length", &length); + } + for (relation, meta) in self.relations() { + if relation == &MetaRelation::Parent { + s.field(&format!("{relation:?}"), &meta.name()); + } else { + s.field(&format!("{relation:?}"), meta); + } + } + s.finish() + } +} + +/// Used internally by the `protocol!` macro to copy from `FieldAccess` in this crate to +/// `FieldAccess` in the generated code. +#[macro_export] +#[doc(hidden)] +macro_rules! 
field_access_copy { + ($acc1:ident :: FieldAccess, $acc2:ident :: FieldAccess, $($ty:ty),*) => { + $( + $crate::field_access_copy!(: $acc1 :: FieldAccess, $acc2 :: FieldAccess, + $ty, + $crate::meta::ZTArray<$ty>, + $crate::meta::Array, + $crate::meta::Array, + $crate::meta::Array, + $crate::meta::Array, + $crate::meta::Array + ); + )* + }; + + (basic $acc1:ident :: FieldAccess, $acc2:ident :: FieldAccess, $($ty:ty),*) => { + $( + + $crate::field_access_copy!(: $acc1 :: FieldAccess, $acc2 :: FieldAccess, + $ty, + $crate::meta::ZTArray<$ty>, + $crate::meta::Array, + $crate::meta::Array, + $crate::meta::Array, + $crate::meta::Array, + $crate::meta::Array + ); + + impl $acc2 :: FieldAccess<$crate::meta::FixedArray> { + #[inline(always)] + pub const fn meta() -> &'static dyn $crate::Meta { + $acc1::FieldAccess::<$crate::meta::FixedArray>::meta() + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> Result { + $acc1::FieldAccess::<$crate::meta::FixedArray>::size_of_field_at(buf) + } + #[inline(always)] + pub const fn extract(buf: &[u8]) -> Result<[<$ty as $crate::Enliven>::WithLifetime<'_>; S], $crate::ParseError> { + $acc1::FieldAccess::<$crate::meta::FixedArray>::extract(buf) + } + pub const fn constant(_: usize) -> $ty { + panic!("Constants unsupported for this data type") + } + #[inline(always)] + pub const fn measure(value: &[<$ty as $crate::Enliven>::ForMeasure<'_>; S]) -> usize { + $acc1::FieldAccess::<$crate::meta::FixedArray>::measure(value) + } + } + )* + }; + (: $acc1:ident :: FieldAccess, $acc2:ident :: FieldAccess, $($ty:ty),*) => { + $( + impl $acc2 :: FieldAccess<$ty> { + #[inline(always)] + pub const fn meta() -> &'static dyn $crate::Meta { + $acc1::FieldAccess::<$ty>::meta() + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> Result { + $acc1::FieldAccess::<$ty>::size_of_field_at(buf) + } + #[inline(always)] + pub const fn extract(buf: &[u8]) -> Result<<$ty as $crate::Enliven>::WithLifetime<'_>, $crate::ParseError> { 
+ $acc1::FieldAccess::<$ty>::extract(buf) + } + pub const fn constant(value: usize) -> <$ty as $crate::Enliven>::WithLifetime<'static> { + $acc1::FieldAccess::<$ty>::constant(value) + } + #[inline(always)] + pub const fn measure(value: &<$ty as $crate::Enliven>::ForMeasure<'_>) -> usize { + $acc1::FieldAccess::<$ty>::measure(value) + } + } + )* + }; +} diff --git a/rust/pgrust/src/protocol/message_group.rs b/rust/db_proto/src/message_group.rs similarity index 84% rename from rust/pgrust/src/protocol/message_group.rs rename to rust/db_proto/src/message_group.rs index 0f5f0720857..a88fade9d59 100644 --- a/rust/pgrust/src/protocol/message_group.rs +++ b/rust/db_proto/src/message_group.rs @@ -1,6 +1,8 @@ -macro_rules! message_group { +#[doc(hidden)] +#[macro_export] +macro_rules! __message_group { ($(#[$doc:meta])* $group:ident : $super:ident = [$($message:ident),*]) => { - paste::paste!( + $crate::paste!( $(#[$doc])* #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[allow(unused)] @@ -21,7 +23,7 @@ macro_rules! message_group { #[allow(unused)] impl [<$group Builder>]<'_> { - pub fn to_vec(&self) -> Vec { + pub fn to_vec(self) -> Vec { match self { $( Self::$message(message) => message.to_vec(), @@ -29,7 +31,7 @@ macro_rules! message_group { } } - pub fn copy_to_buf(&self, writer: &mut $crate::protocol::writer::BufWriter) { + pub fn copy_to_buf(&self, writer: &mut $crate::BufWriter) { match self { $( Self::$message(message) => message.copy_to_buf(writer), @@ -65,7 +67,7 @@ macro_rules! message_group { impl $group { pub fn identify(buf: &[u8]) -> Option { $( - if ::WithLifetime::is_buffer(buf) { + if ::WithLifetime::is_buffer(buf) { return Some(Self::$message); } )* @@ -76,17 +78,19 @@ macro_rules! message_group { ); }; } -pub(crate) use message_group; + +#[doc(inline)] +pub use __message_group as message_group; /// Perform a match on a message. 
/// /// ```rust -/// use pgrust::protocol::*; -/// use pgrust::protocol::postgres::data::*; +/// use db_proto::*; +/// use db_proto::test_protocol::data::*; /// /// let buf = [b'?', 0, 0, 0, 4]; /// match_message!(Message::new(&buf), Backend { -/// (BackendKeyData as data) => { +/// (DataRow as data) => { /// todo!(); /// }, /// unknown => { @@ -102,7 +106,7 @@ macro_rules! __match_message { $unknown:ident => $unknown_impl:block $(,)? }) => { 'block: { - let __message: Result<_, $crate::protocol::ParseError> = $buf; + let __message: Result<_, $crate::ParseError> = $buf; let res = match __message { Ok(__message) => { $( @@ -138,18 +142,15 @@ pub use __match_message as match_message; #[cfg(test)] mod tests { use super::*; - use crate::protocol::postgres::{ - builder, - data::{Message, PasswordMessage}, - }; + use crate::test_protocol::{builder, data::*}; #[test] fn test_match() { let message = builder::Sync::default().to_vec(); let message = Message::new(&message); match_message!(message, Message { - (PasswordMessage as password) => { - eprintln!("{password:?}"); + (DataRow as data_row) => { + eprintln!("{data_row:?}"); return; }, unknown => { diff --git a/rust/db_proto/src/test_protocol.rs b/rust/db_proto/src/test_protocol.rs new file mode 100644 index 00000000000..344eaeed63a --- /dev/null +++ b/rust/db_proto/src/test_protocol.rs @@ -0,0 +1,140 @@ +//! A pseudo-Postgres protocol for testing. +use crate::gen::protocol; + +protocol!( + struct Message { + /// The message type. + mtype: u8, + /// The length of the message contents in bytes, including self. + mlen: len, + /// The message contents. + data: Rest, + } + + /// The `CommandComplete` struct represents a message indicating the successful completion of a command. + struct CommandComplete: Message { + /// Identifies the message as a command-completed response. + mtype: u8 = 'C', + /// Length of message contents in bytes, including self. + mlen: len, + /// The command tag. 
+ tag: ZTString, + } + + /// The `Sync` message is used to synchronize the client and server. + struct Sync: Message { + /// Identifies the message as a synchronization request. + mtype: u8 = 'S', + /// Length of message contents in bytes, including self. + mlen: len, + } + + /// The `DataRow` message represents a row of data returned from a query. + struct DataRow: Message { + /// Identifies the message as a data row. + mtype: u8 = 'D', + /// Length of message contents in bytes, including self. + mlen: len, + /// The values in the row. + values: Array, + } + + struct QueryType { + /// The type of the query parameter. + typ: u8, + /// The length of the query parameter. + len: u32, + /// The metadata of the query parameter. + meta: Array, + } + + struct Query: Message { + /// Identifies the message as a query. + mtype: u8 = 'Q', + /// Length of message contents in bytes, including self. + mlen: len, + /// The query string. + query: ZTString, + /// The types of the query parameters. + types: Array, + } + + struct Key { + /// The key. + key: [u8; 16], + } + + struct Uuids { + /// The UUIDs. 
+ uuids: Array, + } +); + +#[cfg(test)] +mod tests { + use uuid::Uuid; + + use super::*; + + #[test] + fn test_meta() { + let expected = [ + r#"Message { Field("mtype"): u8, Field("mlen"): len, Field("data"): Rest }"#, + r#"CommandComplete { Parent: "Message", Field("mtype"): u8, Field("mlen"): len, Field("tag"): ZTString }"#, + r#"Sync { Parent: "Message", Field("mtype"): u8, Field("mlen"): len }"#, + r#"DataRow { Parent: "Message", Field("mtype"): u8, Field("mlen"): len, Field("values"): Array { Length: i16, Item: Encoded } }"#, + r#"QueryType { Field("typ"): u8, Field("len"): u32, Field("meta"): Array { Length: u32, Item: u8 } }"#, + r#"Query { Parent: "Message", Field("mtype"): u8, Field("mlen"): len, Field("query"): ZTString, Field("types"): Array { Length: i16, Item: QueryType { Field("typ"): u8, Field("len"): u32, Field("meta"): Array { Length: u32, Item: u8 } } } }"#, + r#"Key { Field("key"): FixedArray { Length: 16, Item: u8 } }"#, + r#"Uuids { Field("uuids"): Array { Length: u32, Item: Uuid } }"#, + ]; + + for (i, meta) in meta::ALL.iter().enumerate() { + assert_eq!(expected[i], format!("{meta:?}")); + } + } + + #[test] + fn test_query() { + let buf = builder::Query { + query: "SELECT * from foo", + types: &[builder::QueryType { + typ: 1, + len: 4, + meta: &[1, 2, 3, 4], + }], + } + .to_vec(); + + let query = data::Query::new(&buf).expect("Failed to parse query"); + assert_eq!( + r#"Query { mtype: 81, mlen: 37, query: "SELECT * from foo", types: [QueryType { typ: 1, len: 4, meta: [1, 2, 3, 4] }] }"#, + format!("{query:?}") + ); + } + + #[test] + fn test_fixed_array() { + let buf = builder::Key { + key: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + } + .to_vec(); + + let key = data::Key::new(&buf).expect("Failed to parse key"); + assert_eq!( + key.key(), + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] + ); + } + + #[test] + fn test_uuid() { + let buf = builder::Uuids { + uuids: &[Uuid::NAMESPACE_DNS], + } + .to_vec(); + + let uuids = 
data::Uuids::new(&buf).expect("Failed to parse uuids"); + assert_eq!(uuids.uuids().get(0), Some(Uuid::NAMESPACE_DNS)); + } +} diff --git a/rust/pgrust/src/protocol/writer.rs b/rust/db_proto/src/writer.rs similarity index 100% rename from rust/pgrust/src/protocol/writer.rs rename to rust/db_proto/src/writer.rs diff --git a/rust/pgrust/Cargo.toml b/rust/pgrust/Cargo.toml index a0eaf7acf88..e67df1a7da9 100644 --- a/rust/pgrust/Cargo.toml +++ b/rust/pgrust/Cargo.toml @@ -17,6 +17,7 @@ gel_auth.workspace = true pyo3.workspace = true tokio.workspace = true tracing.workspace = true +db_proto.workspace = true futures = "0" thiserror = "1" @@ -30,7 +31,7 @@ url = "2" serde = "1" serde_derive = "1" percent-encoding = "2" -uuid = "1" +bytemuck = { version = "1", features = ["derive"] } [dependencies.derive_more] version = "1.0.0-beta.6" @@ -38,19 +39,15 @@ features = ["full"] [dev-dependencies] tracing-subscriber.workspace = true -scopeguard = "1" +captive_postgres.workspace = true +scopeguard = "1" pretty_assertions = "1.2.0" test-log = { version = "0", features = ["trace"] } rstest = "0" clap = "4" clap_derive = "4" -tempfile = "3" -socket2 = "0.5.7" libc = "0.2.158" - -[dev-dependencies.tokio] -version = "1" -features = ["macros", "rt-multi-thread", "time", "test-util"] +hex-literal = "0.4.1" [lib] diff --git a/rust/pgrust/examples/connect.rs b/rust/pgrust/examples/connect.rs index bb26dafddcc..d23faaefbe4 100644 --- a/rust/pgrust/examples/connect.rs +++ b/rust/pgrust/examples/connect.rs @@ -1,9 +1,16 @@ +use captive_postgres::{ + setup_postgres, ListenAddress, Mode, DEFAULT_DATABASE, DEFAULT_PASSWORD, DEFAULT_USERNAME, +}; use clap::Parser; use clap_derive::Parser; +use gel_auth::AuthType; use openssl::ssl::{Ssl, SslContext, SslMethod}; use pgrust::{ - connection::{dsn::parse_postgres_dsn_env, Client, Credentials, ResolvedTarget}, - protocol::postgres::data::{DataRow, ErrorResponse, RowDescription}, + connection::{ + dsn::parse_postgres_dsn_env, Client, Credentials, 
ExecuteSink, Format, MaxRows, + PipelineBuilder, Portal, QuerySink, ResolvedTarget, Statement, + }, + protocol::postgres::data::{CopyData, CopyOutResponse, DataRow, ErrorResponse, RowDescription}, }; use std::net::SocketAddr; use tokio::task::LocalSet; @@ -11,6 +18,10 @@ use tokio::task::LocalSet; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Args { + /// Use an ephemeral database + #[clap(short = 'e', long = "ephemeral", conflicts_with_all = &["dsn", "unix", "tcp", "username", "password", "database"])] + ephemeral: bool, + #[clap(short = 'D', long = "dsn", value_parser, conflicts_with_all = &["unix", "tcp", "username", "password", "database"])] dsn: Option, @@ -44,6 +55,10 @@ struct Args { )] database: String, + /// Use extended query syntax + #[clap(short = 'x', long = "extended")] + extended: bool, + /// SQL statements to run #[clap( name = "statements", @@ -54,6 +69,16 @@ struct Args { statements: Option>, } +fn address(address: &ListenAddress) -> ResolvedTarget { + match address { + ListenAddress::Tcp(addr) => ResolvedTarget::SocketAddr(*addr), + #[cfg(unix)] + ListenAddress::Unix(path) => ResolvedTarget::UnixSocketAddr( + std::os::unix::net::SocketAddr::from_pathname(path).unwrap(), + ), + } +} + #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); @@ -61,6 +86,22 @@ async fn main() -> Result<(), Box> { eprintln!("{args:?}"); let mut socket_address: Option = None; + + let _ephemeral = if args.ephemeral { + let process = setup_postgres(AuthType::Trust, Mode::Unix)?; + let Some(process) = process else { + eprintln!("Failed to start ephemeral database"); + return Err("Failed to start ephemeral database".into()); + }; + socket_address = Some(address(&process.socket_address)); + args.username = DEFAULT_USERNAME.to_string(); + args.password = DEFAULT_PASSWORD.to_string(); + args.database = DEFAULT_DATABASE.to_string(); + Some(process) + } else { + None + }; + if let Some(dsn) = args.dsn { let mut 
conn = parse_postgres_dsn_env(&dsn, std::env::vars())?; #[allow(deprecated)] @@ -97,16 +138,96 @@ async fn main() -> Result<(), Box> { .unwrap_or_else(|| vec!["select 1;".to_string()]); let local = LocalSet::new(); local - .run_until(run_queries(socket_address, credentials, statements)) + .run_until(run_queries( + socket_address, + credentials, + statements, + args.extended, + )) .await?; Ok(()) } +fn logging_sink() -> impl QuerySink { + ( + |rows: RowDescription<'_>| { + eprintln!("\nFields:"); + for field in rows.fields() { + eprint!(" {:?}", field.name()); + } + eprintln!(); + let guard = scopeguard::guard((), |_| { + eprintln!("Done"); + }); + move |row: DataRow<'_>| { + let _ = &guard; + eprintln!("Row:"); + for field in row.values() { + eprint!(" {:?}", field); + } + eprintln!(); + } + }, + |_: CopyOutResponse<'_>| { + eprintln!("\nCopy:"); + let guard = scopeguard::guard((), |_| { + eprintln!("Done"); + }); + move |data: CopyData<'_>| { + let _ = &guard; + eprintln!("Chunk:"); + for line in hexdump::hexdump_iter(data.data().as_ref()) { + eprintln!("{line}"); + } + } + }, + |error: ErrorResponse<'_>| { + eprintln!("\nError:\n {:?}", error); + }, + ) +} + +fn logging_sink_execute() -> impl ExecuteSink { + ( + || { + eprintln!(); + let guard = scopeguard::guard((), |_| { + eprintln!("Done"); + }); + move |row: DataRow<'_>| { + let _ = &guard; + eprintln!("Row:"); + for field in row.values() { + eprint!(" {:?}", field); + } + eprintln!(); + } + }, + |_: CopyOutResponse<'_>| { + eprintln!("\nCopy:"); + let guard = scopeguard::guard((), |_| { + eprintln!("Done"); + }); + move |data: CopyData<'_>| { + let _ = &guard; + eprintln!("Chunk:"); + for line in hexdump::hexdump_iter(data.data().as_ref()) { + eprintln!("{line}"); + } + } + }, + |error: ErrorResponse<'_>| { + eprintln!("\nError:\n {:?}", error); + }, + ) +} + async fn run_queries( socket_address: ResolvedTarget, credentials: Credentials, statements: Vec, + extended: bool, ) -> Result<(), Box> { let client = 
socket_address.connect().await?; let ssl = SslContext::builder(SslMethod::tls_client())?.build(); @@ -116,37 +237,36 @@ async fn run_queries( tokio::task::spawn_local(task); conn.ready().await?; - let local = LocalSet::new(); eprintln!("Statements: {statements:?}"); + for statement in statements { - let sink = ( - |rows: RowDescription<'_>| { - eprintln!("\nFields:"); - for field in rows.fields() { - eprint!(" {:?}", field.name()); - } - eprintln!(); - let guard = scopeguard::guard((), |_| { - eprintln!("Done"); - }); - move |row: Result, ErrorResponse<'_>>| { - let _ = &guard; - if let Ok(row) = row { - eprintln!("Row:"); - for field in row.values() { - eprint!(" {:?}", field); - } - eprintln!(); - } - } - }, - |error: ErrorResponse<'_>| { - eprintln!("\nError:\n {:?}", error); - }, - ); - local.spawn_local(conn.query(&statement, sink)); + if extended { + let conn = conn.clone(); + tokio::task::spawn_local(async move { + let pipeline = PipelineBuilder::default() + .parse(Statement::default(), &statement, &[], ()) + .describe_statement(Statement::default(), ()) + .bind( + Portal::default(), + Statement::default(), + &[], + &[Format::text()], + (), + ) + .describe_portal(Portal::default(), ()) + .execute( + Portal::default(), + MaxRows::Unlimited, + logging_sink_execute(), + ) + .build(); + conn.pipeline_sync(pipeline).await + }) + .await??; + } else { + tokio::task::spawn_local(conn.query(&statement, logging_sink())).await??; + } } - local.await; Ok(()) } diff --git a/rust/pgrust/src/connection/conn.rs b/rust/pgrust/src/connection/conn.rs index d26f1f69530..05be12d3c91 100644 --- a/rust/pgrust/src/connection/conn.rs +++ b/rust/pgrust/src/connection/conn.rs @@ -1,27 +1,27 @@ use super::{ connect_raw_ssl, + flow::{MessageHandler, MessageResult, Pipeline, QuerySink}, raw_conn::RawClient, stream::{Stream, StreamWithUpgrade}, Credentials, }; use crate::{ - connection::ConnectionError, + connection::{ + flow::{QueryMessageHandler, SyncMessageHandler}, + ConnectionError, 
+ }, handshake::ConnectionSslRequirement, - protocol::{ - match_message, - postgres::{ - builder, - data::{ - CommandComplete, DataRow, ErrorResponse, Message, ReadyForQuery, RowDescription, - }, - meta, - }, - StructBuffer, + protocol::postgres::{ + builder, + data::{Message, NotificationResponse, ParameterStatus}, + meta, }, }; -use futures::FutureExt; +use db_proto::StructBuffer; +use futures::{future::Either, FutureExt}; use std::{ cell::RefCell, + future::ready, pin::Pin, sync::Arc, task::{ready, Poll}, @@ -35,17 +35,47 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::{error, trace, warn, Level}; #[derive(Debug, thiserror::Error)] -pub enum PGError { +pub enum PGConnError { #[error("Invalid state")] InvalidState, + #[error("Postgres error: {0}")] + PgError(#[from] crate::errors::PgServerError), #[error("Connection failed: {0}")] Connection(#[from] ConnectionError), #[error("I/O error: {0}")] Io(#[from] std::io::Error), + /// If an operation in a pipeline group fails, all operations up to + /// the next sync are skipped. + #[error("Operation skipped because of previous pipeline failure: {0}")] + Skipped(crate::errors::PgServerError), #[error("Connection was closed")] Closed, } +/// A client for a PostgreSQL connection. 
+/// +/// ``` +/// # use pgrust::connection::*; +/// # _ = async { +/// # let config = (); +/// # let credentials = Credentials::default(); +/// # let (client, server) = ::tokio::io::duplex(64); +/// # let socket = client; +/// let (client, task) = Client::new(credentials, socket, config); +/// ::tokio::task::spawn_local(task); +/// +/// // Run a basic query +/// client.query("SELECT 1", ()).await?; +/// +/// // Run a pipelined extended query +/// client.pipeline_sync(PipelineBuilder::default() +/// .parse(Statement("stmt1"), "SELECT 1", &[], ()) +/// .bind(Portal("portal1"), Statement("stmt1"), &[], &[Format::text()], ()) +/// .execute(Portal("portal1"), MaxRows::Unlimited, ()) +/// .build()).await?; +/// # Ok::<(), PGConnError>(()) +/// # } +/// ``` pub struct Client where (B, C): StreamWithUpgrade, @@ -53,6 +83,17 @@ where conn: Rc>, } +impl Clone for Client +where + (B, C): StreamWithUpgrade, +{ + fn clone(&self) -> Self { + Self { + conn: self.conn.clone(), + } + } +} + impl Client where (B, C): StreamWithUpgrade, @@ -63,7 +104,7 @@ where credentials: Credentials, socket: B, config: C, - ) -> (Self, impl Future>) { + ) -> (Self, impl Future>) { let conn = Rc::new(PGConn::new_connection(async move { let ssl_mode = ConnectionSslRequirement::Optional; let raw = connect_raw_ssl(credentials, ssl_mode, config, socket).await?; @@ -74,107 +115,61 @@ where } /// Create a new PostgreSQL client and a background task. - pub fn new_raw(stm: RawClient) -> (Self, impl Future>) { + pub fn new_raw(stm: RawClient) -> (Self, impl Future>) { let conn = Rc::new(PGConn::new_raw(stm)); let task = conn.clone().task(); (Self { conn }, task) } - pub async fn ready(&self) -> Result<(), PGError> { + pub async fn ready(&self) -> Result<(), PGConnError> { self.conn.ready().await } + /// Performs a bare `Query` operation. 
The sink handles the following messages: + /// + /// - `RowDescription` + /// - `DataRow` + /// - `CopyOutResponse` + /// - `CopyData` + /// - `CopyDone` + /// - `EmptyQueryResponse` + /// - `ErrorResponse` + /// + /// `CopyInResponse` is not currently supported and will result in a `CopyFail` being + /// sent to the server. + /// + /// Cancellation safety: if the future is dropped after the first time it is polled, the operation will + /// continue to call any callbacks and run to completion. If it has not been polled, the operation will + /// not be submitted. pub fn query( &self, query: &str, f: impl QuerySink + 'static, - ) -> impl Future> { - self.conn.clone().query(query.to_owned(), f) - } -} - -struct ErasedQuerySink(Q); - -impl QuerySink for ErasedQuerySink -where - Q: QuerySink, - S: DataSink + 'static, -{ - type Output = Box; - fn error(&self, error: ErrorResponse) { - self.0.error(error) - } - fn rows(&self, rows: RowDescription) -> Self::Output { - Box::new(self.0.rows(rows)) - } -} - -pub trait QuerySink { - type Output: DataSink; - fn rows(&self, rows: RowDescription) -> Self::Output; - fn error(&self, error: ErrorResponse); -} - -impl QuerySink for Box -where - Q: QuerySink + 'static, - S: DataSink + 'static, -{ - type Output = Box; - fn rows(&self, rows: RowDescription) -> Self::Output { - Box::new(self.as_ref().rows(rows)) - } - fn error(&self, error: ErrorResponse) { - self.as_ref().error(error) - } -} - -impl QuerySink for (F1, F2) -where - F1: for<'a> Fn(RowDescription) -> S, - F2: for<'a> Fn(ErrorResponse), - S: DataSink, -{ - type Output = S; - fn rows(&self, rows: RowDescription) -> S { - (self.0)(rows) - } - fn error(&self, error: ErrorResponse) { - (self.1)(error) - } -} - -pub trait DataSink { - fn row(&self, values: Result); -} - -impl DataSink for () { - fn row(&self, _: Result) {} -} - -impl DataSink for F -where - F: for<'a> Fn(Result, ErrorResponse<'a>>), -{ - fn row(&self, values: Result) { - (self)(values) + ) -> impl Future> { + 
match self.conn.clone().query(query, f) { + Ok(f) => Either::Left(f), + Err(e) => Either::Right(ready(Err(e))), + } } -} -impl DataSink for Box { - fn row(&self, values: Result) { - self.as_ref().row(values) + /// Performs a set of pipelined steps as a `Sync` group. + /// + /// Cancellation safety: if the future is dropped after the first time it is polled, the operation will + /// continue to call any callbacks and run to completion. If it has not been polled, the operation will + /// not be submitted. + pub fn pipeline_sync( + &self, + pipeline: Pipeline, + ) -> impl Future> { + match self.conn.clone().pipeline_sync(pipeline) { + Ok(f) => Either::Left(f), + Err(e) => Either::Right(ready(Err(e))), + } } } -struct QueryWaiter { - #[allow(unused)] - tx: tokio::sync::mpsc::UnboundedSender<()>, - f: Box>>, - data: RefCell>>, -} - #[derive(derive_more::Debug)] +#[allow(clippy::type_complexity)] enum ConnState where (B, C): StreamWithUpgrade, @@ -183,8 +178,14 @@ where #[allow(clippy::type_complexity)] Connecting(Pin, ConnectionError>>>>), #[debug("Ready(..)")] - Ready(RawClient, VecDeque), - Error(PGError), + Ready { + client: RawClient, + handlers: VecDeque<( + Box, + Option>, + )>, + }, + Error(PGConnError), Closed, } @@ -193,33 +194,39 @@ where (B, C): StreamWithUpgrade, { state: RefCell>, - write_lock: tokio::sync::Mutex<()>, + queue: RefCell>>, ready_lock: Arc>, } impl PGConn where (B, C): StreamWithUpgrade, + B: 'static, + C: 'static, { pub fn new_connection( future: impl Future, ConnectionError>> + 'static, ) -> Self { Self { state: ConnState::Connecting(future.boxed_local()).into(), - write_lock: Default::default(), + queue: Default::default(), ready_lock: Default::default(), } } pub fn new_raw(stm: RawClient) -> Self { Self { - state: ConnState::Ready(stm, Default::default()).into(), - write_lock: Default::default(), + state: ConnState::Ready { + client: stm, + handlers: Default::default(), + } + .into(), + queue: Default::default(), ready_lock: 
Default::default(), } } - fn check_error(&self) -> Result<(), PGError> { + fn check_error(&self) -> Result<(), PGConnError> { let state = &mut *self.state.borrow_mut(); match state { ConnState::Error(..) => { @@ -229,97 +236,146 @@ where error!("Connection failed: {e:?}"); Err(e) } - ConnState::Closed => Err(PGError::Closed), + ConnState::Closed => Err(PGConnError::Closed), _ => Ok(()), } } #[inline(always)] - async fn ready(&self) -> Result<(), PGError> { + async fn ready(&self) -> Result<(), PGConnError> { let _ = self.ready_lock.lock().await; self.check_error() } - fn with_stream(&self, f: F) -> Result + fn with_stream(&self, f: F) -> Result where F: FnOnce(Pin<&mut RawClient>) -> T, { match &mut *self.state.borrow_mut() { - ConnState::Ready(ref mut raw_client, _) => Ok(f(Pin::new(raw_client))), - _ => Err(PGError::InvalidState), + ConnState::Ready { ref mut client, .. } => Ok(f(Pin::new(client))), + _ => Err(PGConnError::InvalidState), } } - async fn write(&self, mut buf: &[u8]) -> Result<(), PGError> { - let _lock = self.write_lock.lock().await; + fn write( + self: Rc, + message_handlers: Vec>, + buf: Vec, + ) -> Result, PGConnError> { + let (tx, rx) = tokio::sync::oneshot::channel(); + + self.clone().queue.borrow_mut().submit(async move { + // If the future was dropped before the first poll, we don't submit the operation + if tx.is_closed() { + return Ok(()); + } - if buf.is_empty() { - return Ok(()); - } - if tracing::enabled!(Level::TRACE) { - trace!("Write:"); - for s in hexdump::hexdump_iter(buf) { - trace!("{}", s); + // Once we're polled the first time, we can add the handlers + match &mut *self.state.borrow_mut() { + ConnState::Ready { handlers, .. 
} => { + let mut handlers_iter = message_handlers.into_iter(); + let mut tx = Some(tx); + while let Some(handler) = handlers_iter.next() { + if handlers_iter.len() == 0 { + handlers.push_back((handler, tx.take())); + } else { + handlers.push_back((handler, None)); + } + } + } + x => { + warn!("Connection state was not ready: {x:?}"); + return Err(PGConnError::InvalidState); + } } - } - loop { - let n = poll_fn(|cx| { - self.with_stream(|stm| { - let n = match ready!(stm.poll_write(cx, buf)) { - Ok(n) => n, - Err(e) => return Poll::Ready(Err(PGError::Io(e))), - }; - Poll::Ready(Ok(n)) - })? - }) - .await?; - if n == buf.len() { - break; + + if tracing::enabled!(Level::TRACE) { + trace!("Write:"); + for s in hexdump::hexdump_iter(&buf) { + trace!("{}", s); + } } - buf = &buf[n..]; - } - Ok(()) + + let mut buf = &buf[..]; + + loop { + let n = poll_fn(|cx| { + self.with_stream(|stm| { + let n = match ready!(stm.poll_write(cx, buf)) { + Ok(n) => n, + Err(e) => return Poll::Ready(Err(PGConnError::Io(e))), + }; + Poll::Ready(Ok(n)) + })? + }) + .await?; + if n == buf.len() { + break; + } + buf = &buf[n..]; + } + + Ok(()) + }); + + Ok(rx) } - fn process_message(&self, message: Option) -> Result<(), PGError> { + fn process_message(&self, message: Option) -> Result<(), PGConnError> { let state = &mut *self.state.borrow_mut(); match state { - ConnState::Ready(_, queue) => { - let message = message.ok_or(PGError::InvalidState); - match_message!(Ok(message?), Backend { - (RowDescription as row) => { - if let Some(qw) = queue.back() { - let qs = qw.f.rows(row); - *qw.data.borrow_mut() = Some(qs); - } - }, - (DataRow as row) => { - if let Some(qw) = queue.back() { - if let Some(qs) = &*qw.data.borrow() { - qs.row(Ok(row)) + ConnState::Ready { handlers, .. 
} => { + let message = message.ok_or(PGConnError::InvalidState)?; + if NotificationResponse::try_new(&message).is_some() { + warn!("Notification: {:?}", message); + return Ok(()); + } + if ParameterStatus::try_new(&message).is_some() { + warn!("ParameterStatus: {:?}", message); + return Ok(()); + } + if let Some((handler, _tx)) = handlers.front_mut() { + match handler.handle(message) { + MessageResult::SkipUntilSync => { + let mut found_sync = false; + let name = handler.name(); + while let Some((handler, _)) = handlers.front() { + if handler.is_sync() { + found_sync = true; + break; + } + trace!("skipping {}", handler.name()); + handlers.pop_front(); + } + if !found_sync { + warn!("Unexpected state in {name}: No sync handler found"); } } - }, - (CommandComplete) => { - if let Some(qw) = queue.back() { - *qw.data.borrow_mut() = None; + MessageResult::Continue => {} + MessageResult::Done => { + handlers.pop_front(); } - }, - (ReadyForQuery) => { - queue.pop_front(); - }, - (ErrorResponse as err) => { - if let Some(qw) = queue.back() { - qw.f.error(err); + MessageResult::Unknown => { + // TODO: Should this be exposed to the API consumer? + warn!( + "Unknown message in {} ({:?})", + handler.name(), + message.mtype() as char + ); } - }, - unknown => { - eprintln!("Unknown message: {unknown:?}"); - } - }); + MessageResult::UnexpectedState { complaint } => { + // TODO: Should this be exposed to the API consumer? + warn!( + "Unexpected state in {} while handling message ({:?}): {complaint}", + handler.name(), + message.mtype() as char + ); + } + }; + }; } ConnState::Connecting(..) => { - return Err(PGError::InvalidState); + return Err(PGConnError::InvalidState); } ConnState::Error(..) 
| ConnState::Closed => self.check_error()?, } @@ -327,9 +383,8 @@ where Ok(()) } - pub fn task(self: Rc) -> impl Future> { + pub fn task(self: Rc) -> impl Future> { let ready_lock = self.ready_lock.clone().try_lock_owned().unwrap(); - async move { poll_fn(|cx| { let mut state = self.state.borrow_mut(); @@ -339,17 +394,20 @@ where let raw = match result { Ok(raw) => raw, Err(e) => { - let error = PGError::Connection(e); + let error = PGConnError::Connection(e); *state = ConnState::Error(error); - return Poll::Ready(Ok::<_, PGError>(())); + return Poll::Ready(Ok::<_, PGConnError>(())); } }; - *state = ConnState::Ready(raw, VecDeque::new()); - Poll::Ready(Ok::<_, PGError>(())) + *state = ConnState::Ready { + client: raw, + handlers: Default::default(), + }; + Poll::Ready(Ok::<_, PGConnError>(())) } Poll::Pending => Poll::Pending, }, - ConnState::Ready(..) => Poll::Ready(Ok(())), + ConnState::Ready { .. } => Poll::Ready(Ok(())), ConnState::Error(..) | ConnState::Closed => Poll::Ready(self.check_error()), } }) @@ -361,10 +419,15 @@ where loop { let mut read_buffer = [0; 1024]; let n = poll_fn(|cx| { + // Poll the queue before we poll the read stream. Note that we toss + // the result here. Either we'll make progress or there's nothing to + // do. + while self.queue.borrow_mut().poll_next_unpin(cx).is_ready() {} + self.with_stream(|stm| { let mut buf = ReadBuf::new(&mut read_buffer); let res = ready!(stm.poll_read(cx, &mut buf)); - Poll::Ready(res.map(|_| buf.filled().len())).map_err(PGError::Io) + Poll::Ready(res.map(|_| buf.filled().len())).map_err(PGConnError::Io) })? 
}) .await?; @@ -377,6 +440,14 @@ where } buffer.push_fallible(&read_buffer[..n], |message| { + if let Ok(message) = &message { + if tracing::enabled!(Level::TRACE) { + trace!("Message ({:?})", message.mtype() as char); + for s in hexdump::hexdump_iter(message.__buf) { + trace!("{}", s); + } + } + }; self.process_message(Some(message.map_err(ConnectionError::ParseError)?)) })?; @@ -388,35 +459,526 @@ where } } - pub async fn query( + pub fn query( self: Rc, - query: String, + query: &str, f: impl QuerySink + 'static, - ) -> Result<(), PGError> { + ) -> Result>, PGConnError> { trace!("Query task started: {query}"); - let mut rx = match &mut *self.state.borrow_mut() { - ConnState::Ready(_, queue) => { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - let f = Box::new(ErasedQuerySink(f)) as _; - queue.push_back(QueryWaiter { - tx, - f, - data: None.into(), - }); - rx + let message = builder::Query { query }.to_vec(); + let rx = self.write( + vec![Box::new(QueryMessageHandler { + sink: f, + data: None, + copy: None, + })], + message, + )?; + Ok(async { + _ = rx.await; + Ok(()) + }) + } + + pub fn pipeline_sync( + self: Rc, + pipeline: Pipeline, + ) -> Result>, PGConnError> { + trace!("Pipeline task started"); + let Pipeline { + mut messages, + mut handlers, + } = pipeline; + handlers.push(Box::new(SyncMessageHandler)); + messages.extend_from_slice(&builder::Sync::default().to_vec()); + + let rx = self.write(handlers, messages)?; + Ok(async { + _ = rx.await; + Ok(()) + }) + } +} + +#[cfg(test)] +mod tests { + use hex_literal::hex; + use std::{fmt::Write, time::Duration}; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt, DuplexStream}, + task::LocalSet, + time::timeout, + }; + + use crate::connection::{ + flow::{CopyDataSink, DataSink, DoneHandling}, + raw_conn::ConnectionParams, + }; + use crate::protocol::postgres::data::*; + + use super::*; + + impl QuerySink for Rc> { + type Output = Self; + type CopyOutput = Self; + fn rows(&mut self, rows: RowDescription) 
-> Self { + write!(self.borrow_mut(), "[table=[").unwrap(); + for field in rows.fields() { + write!(self.borrow_mut(), "{},", field.name().to_string_lossy()).unwrap(); + } + write!(self.borrow_mut(), "]").unwrap(); + self.clone() + } + fn copy(&mut self, copy: CopyOutResponse) -> Self { + write!( + self.borrow_mut(), + "[copy={:?} {:?}", + copy.format(), + copy.format_codes() + ) + .unwrap(); + self.clone() + } + fn error(&mut self, error: ErrorResponse) { + for field in error.fields() { + if field.etype() as char == 'C' { + write!( + self.borrow_mut(), + "[error {}]", + field.value().to_string_lossy() + ) + .unwrap(); + return; + } + } + write!(self.borrow_mut(), "[error ??? {:?}]", error).unwrap(); + } + } + + impl DataSink for Rc> { + fn row(&mut self, row: DataRow) { + write!(self.borrow_mut(), "[").unwrap(); + for value in row.values() { + write!(self.borrow_mut(), "{},", value.to_string_lossy()).unwrap(); + } + write!(self.borrow_mut(), "]").unwrap(); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(complete) => { + write!( + self.borrow_mut(), + " done={}]", + complete.tag().to_string_lossy() + ) + .unwrap(); + } + Err(error) => { + for field in error.fields() { + if field.etype() as char == 'C' { + write!( + self.borrow_mut(), + "[error {}]]", + field.value().to_string_lossy() + ) + .unwrap(); + return DoneHandling::Handled; + } + } + write!(self.borrow_mut(), "[error ??? 
{:?}]]", error).unwrap(); + } } - x => { - warn!("Connection state was not ready: {x:?}"); - return Err(PGError::InvalidState); + DoneHandling::Handled + } + } + + impl CopyDataSink for Rc> { + fn data(&mut self, data: CopyData) { + write!( + self.borrow_mut(), + "[{}]", + String::from_utf8_lossy(data.data().as_ref()) + ) + .unwrap(); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(complete) => { + write!( + self.borrow_mut(), + " done={}]", + complete.tag().to_string_lossy() + ) + .unwrap(); + } + Err(error) => { + for field in error.fields() { + if field.etype() as char == 'C' { + write!( + self.borrow_mut(), + "[error {}]]", + field.value().to_string_lossy() + ) + .unwrap(); + return DoneHandling::Handled; + } + } + write!(self.borrow_mut(), "[error ??? {:?}]]", error).unwrap(); + } + } + DoneHandling::Handled + } + } + + async fn read_expect(stream: &mut S, expected: &[u8]) { + let mut buf = vec![0u8; expected.len()]; + stream.read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, expected); + } + + /// Perform a test using captured binary protocol data from a real server. 
+ async fn run_expect( + query_task: impl FnOnce(Client, Rc>) -> F + 'static, + expect: &'static [(&[u8], &[u8], &str)], + ) { + let f = async move { + let (mut s1, s2) = tokio::io::duplex(1024 * 1024); + + let (client, task) = Client::new_raw(RawClient::new(s2, ConnectionParams::default())); + let task_handle = tokio::task::spawn_local(task); + + let handle = tokio::task::spawn_local(async move { + let log = Rc::new(RefCell::new(String::new())); + query_task(client, log.clone()).await; + Rc::try_unwrap(log).unwrap().into_inner() + }); + + let mut log_expect = String::new(); + for (read, write, expect) in expect { + // Query[text=""] + eprintln!("read {read:?}"); + read_expect(&mut s1, read).await; + eprintln!("write {write:?}"); + s1.write_all(write).await.unwrap(); + log_expect.push_str(expect); } + + let log = handle.await.unwrap(); + + assert_eq!(log, log_expect); + + // EOF to trigger the task to exit + drop(s1); + + task_handle.await.unwrap().unwrap(); }; - let message = builder::Query { query: &query }.to_vec(); - self.write(&message).await?; - rx.recv().await; - Ok(()) + let local = LocalSet::new(); + let task = local.spawn_local(f); + + timeout(Duration::from_secs(1), local).await.unwrap(); + + // Ensure we detect panics inside the task + task.await.unwrap(); } -} -#[cfg(test)] -mod tests {} + #[test_log::test(tokio::test)] + async fn query_select_1() { + run_expect( + |client, log| async move { + client.query("SELECT 1", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 0d53454c 45435420 3100"), + // T, D, C, Z + &hex!("54000000 2100013f 636f6c75 6d6e3f00 00000000 00000000 00170004 ffffffff 00004400 00000b00 01000000 01314300 00000d53 454c4543 54203100 5a000000 0549"), + "[table=[?column?,][1,] done=SELECT 1]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_select_1_limit_0() { + run_expect( + |client, log| async move { + client.query("SELECT 1 LIMIT 0", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 
1553454c 45435420 31204c49 4d495420 3000"), + // T, C, Z + &hex!("54000000 2100013f 636f6c75 6d6e3f00 00000000 00000000 00170004 ffffffff 00004300 00000d53 454c4543 54203000 5a000000 0549"), + "[table=[?column?,] done=SELECT 0]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_copy_1() { + run_expect( + |client, log| async move { + client.query("copy (select 1) to stdout;", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 1f636f70 79202873 656c6563 74203129 20746f20 7374646f 75743b00"), + // H, d, c, C, Z + &hex!("48000000 09000001 00006400 00000631 0a630000 00044300 00000b43 4f505920 31005a00 00000549"), + "[copy=0 [0][1\n] done=COPY 1]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_copy_1_limit_0() { + run_expect( + |client, log| async move { + client.query("copy (select 1 limit 0) to stdout;", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 27636f70 79202873 656c6563 74203120 6c696d69 74203029 20746f20 7374646f 75743b00"), + // H, c, C, Z + &hex!("48000000 09000001 00006300 00000443 0000000b 434f5059 2030005a 00000005 49"), + "[copy=0 [0] done=COPY 0]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_copy_with_error_rows() { + run_expect( + |client, log| async move { + client.query("copy (select case when id = 2 then id/(id-2) else id end from (select generate_series(1,2) as id)) to stdout;", log.clone()).await.unwrap(); + }, + &[( + &hex!(""" + 51000000 72636f70 79202873 656c6563 + 74206361 73652077 68656e20 6964203d + 20322074 68656e20 69642f28 69642d32 + 2920656c 73652069 6420656e 64206672 + 6f6d2028 73656c65 63742067 656e6572 + 6174655f 73657269 65732831 2c322920 + 61732069 64292920 746f2073 74646f75 + 743b00 + """), + // H, d, E, Z + &hex!(""" + 48000000 09000001 00006400 00000631 + 0a450000 00415345 52524f52 00564552 + 524f5200 43323230 3132004d 64697669 + 73696f6e 20627920 7a65726f 0046696e + 742e6300 4c383431 0052696e 74346469 + 7600005a 00000005 49 
+ """), + "[copy=0 [0][1\n][error 22012]]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_error() { + run_expect( + |client, log| async move { + client.query("do $$begin raise exception 'hi'; end$$;", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 2c646f20 24246265 67696e20 72616973 65206578 63657074 696f6e20 27686927 3b20656e 6424243b 00"), + // E, Z + &hex!(""" + 45000000 75534552 524f5200 56455252 + 4f520043 50303030 31004d68 69005750 + 4c2f7067 53514c20 66756e63 74696f6e + 20696e6c 696e655f 636f6465 5f626c6f + 636b206c 696e6520 31206174 20524149 + 53450046 706c5f65 7865632e 63004c33 + 39313100 52657865 635f7374 6d745f72 + 61697365 00005a00 00000549 + """), + "[error P0001]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_empty_do() { + run_expect( + |client, log| async move { + client + .query("do $$begin end$$;", log.clone()) + .await + .unwrap(); + }, + &[( + &hex!("51000000 16646f20 24246265 67696e20 656e6424 243b00"), + // C, Z + &hex!(""" + 43000000 07444f00 5a000000 0549 + """), + "", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_error_with_rows() { + run_expect( + |client, log| async move { + client.query("select case when id = 2 then id/(id-2) else 1 end from (select 1 as id union all select 2 as id);", log.clone()).await.unwrap(); + }, + &[( + &hex!(""" + 51000000 6673656c 65637420 63617365 + 20776865 6e206964 203d2032 20746865 + 6e206964 2f286964 2d322920 656c7365 + 20312065 6e642066 726f6d20 2873656c + 65637420 31206173 20696420 756e696f + 6e20616c 6c207365 6c656374 20322061 + 73206964 293b00 + """), + // T, D, E, Z + &hex!(""" + 54000000 1d000163 61736500 00000000 + 00000000 00170004 ffffffff 00004400 + 00000b00 01000000 01314500 00004153 + 4552524f 52005645 52524f52 00433232 + 30313200 4d646976 6973696f 6e206279 + 207a6572 6f004669 6e742e63 004c3834 + 31005269 6e743464 69760000 5a000000 + 0549 + """), + "[table=[case,][1,][error 22012]]", + )], + ) + 
.await; + } + + #[test_log::test(tokio::test)] + async fn query_second_errors() { + run_expect( + |client, log| async move { + client + .query("select; select 1/0;", log.clone()) + .await + .unwrap(); + }, + &[( + &hex!("51000000 1873656c 6563743b 2073656c 65637420 312f303b 00"), + // T, D, C, E, Z + &hex!(""" + 54000000 06000044 00000006 00004300 + 00000d53 454c4543 54203100 45000000 + 41534552 524f5200 56455252 4f520043 + 32323031 32004d64 69766973 696f6e20 + 6279207a 65726f00 46696e74 2e63004c + 38343100 52696e74 34646976 00005a00 + 00000549 + """), + "[table=[][] done=SELECT 1][error 22012]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_notification() { + run_expect( + |client, log| async move { + client + .query("listen a; select pg_notify('a','b')", log.clone()) + .await + .unwrap(); + }, + &[( + &hex!( + " + 51000000 286c6973 74656e20 613b2073 + 656c6563 74207067 5f6e6f74 69667928 + 2761272c 27622729 00 + " + ), + // C, T, D, C, A, Z + &hex!( + " + 43000000 0b4c4953 54454e00 54000000 + 22000170 675f6e6f 74696679 00000000 + 00000000 0008e600 04ffffff ff000044 + 0000000a 00010000 00004300 00000d53 + 454c4543 54203100 41000000 0c002cba + 5f610062 005a0000 000549 + " + ), + "[table=[pg_notify,][,] done=SELECT 1]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_two_empty() { + run_expect( + |client, log| async move { + client.query("", log.clone()).await.unwrap(); + client.query("", log.clone()).await.unwrap(); + }, + &[ + ( + &hex!("51000000 0500"), + // I, Z + &hex!("49000000 045a0000 000549"), + "", + ), + ( + &hex!("51000000 0500"), + // I, Z + &hex!("49000000 045a0000 000549"), + "", + ), + ], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_two_error() { + run_expect( + |client, log| async move { + client.query(".", log.clone()).await.unwrap(); + client.query(".", log.clone()).await.unwrap(); + }, + &[ + ( + &hex!("51000000 062e00"), + // E, Z + &hex!(""" + 45000000 59534552 
524f5200 56455252 + 4f520043 34323630 31004d73 796e7461 + 78206572 726f7220 6174206f 72206e65 + 61722022 2e220050 31004673 63616e2e + 6c004c31 32343400 52736361 6e6e6572 + 5f797965 72726f72 00005a00 00000549 + """), + "[error 42601]", + ), + ( + &hex!("51000000 062e00"), + // E, Z + &hex!(""" + 45000000 59534552 524f5200 56455252 + 4f520043 34323630 31004d73 796e7461 + 78206572 726f7220 6174206f 72206e65 + 61722022 2e220050 31004673 63616e2e + 6c004c31 32343400 52736361 6e6e6572 + 5f797965 72726f72 00005a00 00000549 + """), + "[error 42601]", + ), + ], + ) + .await; + } +} diff --git a/rust/pgrust/src/connection/flow.rs b/rust/pgrust/src/connection/flow.rs new file mode 100644 index 00000000000..eed27da32eb --- /dev/null +++ b/rust/pgrust/src/connection/flow.rs @@ -0,0 +1,1231 @@ +//! Postgres flow notes: +//! +//! +//! +//! +//! +//! Extended query messages Parse, Bind, Describe, Execute, Close put the server +//! into a "skip-til-sync" mode when erroring. All messages other than Terminate (including +//! those not part of the extended query protocol) are skipped until an explicit Sync message is received. +//! +//! Sync closes _implicit_ but not _explicit_ transactions. +//! +//! Both Query and Execute may return COPY responses rather than rows. In the case of Query, +//! RowDescription + DataRow is replaced by CopyOutResponse + CopyData + CopyDone. In the case +//! of Execute, describing the portal will return NoData, but Execute will return CopyOutResponse + +//! CopyData + CopyDone. 
+ +use std::{cell::RefCell, num::NonZeroU32, rc::Rc}; + +use crate::protocol::postgres::{ + builder, + data::{ + BindComplete, CloseComplete, CommandComplete, CopyData, CopyDone, CopyOutResponse, DataRow, + EmptyQueryResponse, ErrorResponse, Message, NoData, NoticeResponse, ParameterDescription, + ParseComplete, PortalSuspended, ReadyForQuery, RowDescription, + }, +}; +use db_proto::{match_message, Encoded}; + +#[derive(Debug, Clone, Copy)] +pub enum Param<'a> { + Null, + Text(&'a str), + Binary(&'a [u8]), +} + +#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +#[repr(transparent)] +pub struct Oid(u32); + +impl Oid { + pub fn unspecified() -> Self { + Self(0) + } + + pub fn from(oid: NonZeroU32) -> Self { + Self(oid.get()) + } +} + +#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +#[repr(transparent)] +pub struct Format(i16); + +impl Format { + pub fn text() -> Self { + Self(0) + } + + pub fn binary() -> Self { + Self(1) + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(i32)] +pub enum MaxRows { + Unlimited, + Limited(NonZeroU32), +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Portal<'a>(pub &'a str); + +#[derive(Debug, Clone, Copy, Default)] +pub struct Statement<'a>(pub &'a str); + +pub trait Flow { + fn to_vec(&self) -> Vec; +} + +/// Performs a prepared statement parse operation. +/// +/// Handles: +/// - `ParseComplete` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct ParseFlow<'a> { + pub name: Statement<'a>, + pub query: &'a str, + pub param_types: &'a [Oid], +} + +/// Performs a prepared statement bind operation. +/// +/// Handles: +/// - `BindComplete` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct BindFlow<'a> { + pub portal: Portal<'a>, + pub statement: Statement<'a>, + pub params: &'a [Param<'a>], + pub result_format_codes: &'a [Format], +} + +/// Performs a prepared statement execute operation. 
+/// +/// Handles: +/// - `CommandComplete` +/// - `DataRow` +/// - `PortalSuspended` +/// - `CopyOutResponse` +/// - `CopyData` +/// - `CopyDone` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct ExecuteFlow<'a> { + pub portal: Portal<'a>, + pub max_rows: MaxRows, +} + +/// Performs a portal describe operation. +/// +/// Handles: +/// - `RowDescription` +/// - `NoData` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct DescribePortalFlow<'a> { + pub name: Portal<'a>, +} + +/// Performs a statement describe operation. +/// +/// Handles: +/// - `RowDescription` +/// - `NoData` +/// - `ParameterDescription` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct DescribeStatementFlow<'a> { + pub name: Statement<'a>, +} + +/// Performs a portal close operation. +/// +/// Handles: +/// - `CloseComplete` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct ClosePortalFlow<'a> { + pub name: Portal<'a>, +} + +/// Performs a statement close operation. +/// +/// Handles: +/// - `CloseComplete` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct CloseStatementFlow<'a> { + pub name: Statement<'a>, +} + +/// Performs a query operation. 
+/// +/// Handles: +/// - `EmptyQueryResponse`: If no queries were specified in the text +/// - `CommandComplete`: For each fully-completed query +/// - `RowDescription`: For each query that returns data +/// - `DataRow`: For each row returned by a query +/// - `CopyOutResponse`: For each query that returns copy data +/// - `CopyData`: For each chunk of copy data returned by a query +/// - `CopyDone`: For each query that returns copy data +/// - `ErrorResponse`: For the first failed query +#[derive(Debug, Clone, Copy)] +struct QueryFlow<'a> { + pub query: &'a str, +} + +impl<'a> Flow for ParseFlow<'a> { + fn to_vec(&self) -> Vec { + let param_types = bytemuck::cast_slice(self.param_types); + builder::Parse { + statement: self.name.0, + query: self.query, + param_types, + } + .to_vec() + } +} + +impl<'a> Flow for BindFlow<'a> { + fn to_vec(&self) -> Vec { + let mut format_codes = Vec::with_capacity(self.params.len()); + let mut values = Vec::with_capacity(self.params.len()); + + for param in self.params { + match param { + Param::Null => { + format_codes.push(0); + values.push(Encoded::Null); + } + Param::Text(value) => { + format_codes.push(0); + values.push(Encoded::Value(value.as_bytes())); + } + Param::Binary(value) => { + format_codes.push(1); + values.push(Encoded::Value(value)); + } + } + } + + let result_format_codes = bytemuck::cast_slice(self.result_format_codes); + + builder::Bind { + portal: self.portal.0, + statement: self.statement.0, + format_codes: &format_codes, + values: &values, + result_format_codes, + } + .to_vec() + } +} + +impl<'a> Flow for ExecuteFlow<'a> { + fn to_vec(&self) -> Vec { + let max_rows = match self.max_rows { + MaxRows::Unlimited => 0, + MaxRows::Limited(n) => n.get() as i32, + }; + builder::Execute { + portal: self.portal.0, + max_rows, + } + .to_vec() + } +} + +impl<'a> Flow for DescribePortalFlow<'a> { + fn to_vec(&self) -> Vec { + builder::Describe { + name: self.name.0, + dtype: b'P', + } + .to_vec() + } +} + +impl<'a> Flow 
for DescribeStatementFlow<'a> { + fn to_vec(&self) -> Vec { + builder::Describe { + name: self.name.0, + dtype: b'S', + } + .to_vec() + } +} + +impl<'a> Flow for ClosePortalFlow<'a> { + fn to_vec(&self) -> Vec { + builder::Close { + name: self.name.0, + ctype: b'P', + } + .to_vec() + } +} + +impl<'a> Flow for CloseStatementFlow<'a> { + fn to_vec(&self) -> Vec { + builder::Close { + name: self.name.0, + ctype: b'S', + } + .to_vec() + } +} + +impl<'a> Flow for QueryFlow<'a> { + fn to_vec(&self) -> Vec { + builder::Query { query: self.query }.to_vec() + } +} + +pub(crate) enum MessageResult { + Continue, + Done, + SkipUntilSync, + Unknown, + UnexpectedState { complaint: &'static str }, +} + +pub(crate) trait MessageHandler { + fn handle(&mut self, message: Message) -> MessageResult; + fn name(&self) -> &'static str; + fn is_sync(&self) -> bool { + false + } +} + +pub(crate) struct SyncMessageHandler; + +impl MessageHandler for SyncMessageHandler { + fn handle(&mut self, message: Message) -> MessageResult { + if ReadyForQuery::try_new(&message).is_some() { + return MessageResult::Done; + } + MessageResult::Unknown + } + fn name(&self) -> &'static str { + "Sync" + } + fn is_sync(&self) -> bool { + true + } +} + +impl MessageHandler for (&'static str, F) +where + F: for<'a> FnMut(Message<'a>) -> MessageResult, +{ + fn handle(&mut self, message: Message) -> MessageResult { + (self.1)(message) + } + fn name(&self) -> &'static str { + self.0 + } +} + +pub trait FlowWithSink { + fn visit_flow(&self, f: impl FnMut(&dyn Flow)); + fn make_handler(self) -> Box; +} + +pub trait SimpleFlowSink { + fn handle(&mut self, result: Result<(), ErrorResponse>); +} + +impl SimpleFlowSink for () { + fn handle(&mut self, _: Result<(), ErrorResponse>) {} +} + +impl FnMut(Result<(), ErrorResponse>)> SimpleFlowSink for F { + fn handle(&mut self, result: Result<(), ErrorResponse>) { + (self)(result) + } +} + +impl FlowWithSink for (ParseFlow<'_>, S) { + fn visit_flow(&self, mut f: impl 
FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(mut self) -> Box { + Box::new(("Parse", move |message: Message<'_>| { + if ParseComplete::try_new(&message).is_some() { + self.1.handle(Ok(())); + return MessageResult::Done; + } + if let Some(msg) = ErrorResponse::try_new(&message) { + self.1.handle(Err(msg)); + return MessageResult::SkipUntilSync; + } + MessageResult::Unknown + })) + } +} + +impl FlowWithSink for (BindFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(mut self) -> Box { + Box::new(("Bind", move |message: Message<'_>| { + if BindComplete::try_new(&message).is_some() { + self.1.handle(Ok(())); + return MessageResult::Done; + } + if let Some(msg) = ErrorResponse::try_new(&message) { + self.1.handle(Err(msg)); + return MessageResult::SkipUntilSync; + } + MessageResult::Unknown + })) + } +} + +impl FlowWithSink for (ClosePortalFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(mut self) -> Box { + Box::new(("ClosePortal", move |message: Message<'_>| { + if CloseComplete::try_new(&message).is_some() { + self.1.handle(Ok(())); + return MessageResult::Done; + } + if let Some(msg) = ErrorResponse::try_new(&message) { + self.1.handle(Err(msg)); + return MessageResult::SkipUntilSync; + } + MessageResult::Unknown + })) + } +} + +impl FlowWithSink for (CloseStatementFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(mut self) -> Box { + Box::new(("CloseStatement", move |message: Message<'_>| { + if CloseComplete::try_new(&message).is_some() { + self.1.handle(Ok(())); + return MessageResult::Done; + } + if let Some(msg) = ErrorResponse::try_new(&message) { + self.1.handle(Err(msg)); + /* A failed Close, like any extended-query message, makes the server discard input until Sync; skip the queued handlers up to the Sync handler, matching ClosePortal. */ + return MessageResult::SkipUntilSync; + } + MessageResult::Unknown + })) + } +} + +impl FlowWithSink for (ExecuteFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn 
make_handler(self) -> Box { + Box::new(ExecuteMessageHandler { + sink: self.1, + data: None, + copy: None, + }) + } +} + +impl FlowWithSink for (QueryFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(self) -> Box { + Box::new(QueryMessageHandler { + sink: self.1, + data: None, + copy: None, + }) + } +} + +impl FlowWithSink for (DescribePortalFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(self) -> Box { + Box::new(DescribeMessageHandler { sink: self.1 }) + } +} + +impl FlowWithSink for (DescribeStatementFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(self) -> Box { + Box::new(DescribeMessageHandler { sink: self.1 }) + } +} + +pub trait DescribeSink { + fn params(&mut self, params: ParameterDescription); + fn rows(&mut self, rows: RowDescription); + fn error(&mut self, error: ErrorResponse); +} + +impl DescribeSink for () { + fn params(&mut self, _: ParameterDescription) {} + fn rows(&mut self, _: RowDescription) {} + fn error(&mut self, _: ErrorResponse) {} +} + +impl DescribeSink for F +where + F: for<'a> FnMut(RowDescription<'a>), +{ + fn rows(&mut self, rows: RowDescription) { + (self)(rows) + } + fn params(&mut self, _params: ParameterDescription) {} + fn error(&mut self, _error: ErrorResponse) {} +} + +impl DescribeSink for (F1, F2) +where + F1: for<'a> FnMut(ParameterDescription<'a>), + F2: for<'a> FnMut(RowDescription<'a>), +{ + fn params(&mut self, params: ParameterDescription) { + (self.0)(params) + } + fn rows(&mut self, rows: RowDescription) { + (self.1)(rows) + } + fn error(&mut self, _error: ErrorResponse) {} +} + +struct DescribeMessageHandler { + sink: S, +} + +impl MessageHandler for DescribeMessageHandler { + fn name(&self) -> &'static str { + "Describe" + } + fn handle(&mut self, message: Message) -> MessageResult { + match_message!(Ok(message), Backend { + (ParameterDescription 
as params) => { + self.sink.params(params); + return MessageResult::Continue; + }, + (RowDescription as rows) => { + self.sink.rows(rows); + return MessageResult::Done; + }, + (NoData) => { + return MessageResult::Done; + }, + (ErrorResponse as err) => { + self.sink.error(err); + return MessageResult::SkipUntilSync; + }, + _unknown => { + return MessageResult::Unknown; + } + }) + } +} + +pub trait ExecuteSink { + type Output: ExecuteDataSink; + type CopyOutput: CopyDataSink; + + fn rows(&mut self) -> Self::Output; + fn copy(&mut self, copy: CopyOutResponse) -> Self::CopyOutput; + fn complete(&mut self, _complete: ExecuteCompletion) {} + fn notice(&mut self, _: NoticeResponse) {} + fn error(&mut self, error: ErrorResponse); +} + +pub enum ExecuteCompletion<'a> { + PortalSuspended(PortalSuspended<'a>), + CommandComplete(CommandComplete<'a>), +} + +impl ExecuteSink for () { + type Output = (); + type CopyOutput = (); + fn rows(&mut self) {} + fn copy(&mut self, _: CopyOutResponse) {} + fn error(&mut self, _: ErrorResponse) {} +} + +impl ExecuteSink for (F1, F2) +where + F1: for<'a> FnMut() -> S, + F2: for<'a> FnMut(ErrorResponse<'a>), + S: ExecuteDataSink, +{ + type Output = S; + type CopyOutput = (); + fn rows(&mut self) -> S { + (self.0)() + } + fn copy(&mut self, _: CopyOutResponse) {} + fn error(&mut self, error: ErrorResponse) { + (self.1)(error) + } +} + +impl ExecuteSink for (F1, F2, F3) +where + F1: for<'a> FnMut() -> S, + F2: for<'a> FnMut(CopyOutResponse<'a>) -> T, + F3: for<'a> FnMut(ErrorResponse<'a>), + S: ExecuteDataSink, + T: CopyDataSink, +{ + type Output = S; + type CopyOutput = T; + fn rows(&mut self) -> S { + (self.0)() + } + fn copy(&mut self, copy: CopyOutResponse) -> T { + (self.1)(copy) + } + fn error(&mut self, error: ErrorResponse) { + (self.2)(error) + } +} + +pub trait ExecuteDataSink { + /// Sink a row of data. + fn row(&mut self, values: DataRow); + /// Handle the completion of a command. If unimplemented, will be redirected to the parent. 
+ #[must_use] + fn done(&mut self, _result: Result) -> DoneHandling { + DoneHandling::RedirectToParent + } +} + +impl ExecuteDataSink for () { + fn row(&mut self, _: DataRow) {} +} + +impl ExecuteDataSink for F +where + F: for<'a> Fn(DataRow<'a>), +{ + fn row(&mut self, values: DataRow) { + (self)(values) + } +} + +/// A sink capable of handling standard query and COPY (out direction) messages. +pub trait QuerySink { + type Output: DataSink; + type CopyOutput: CopyDataSink; + + fn rows(&mut self, rows: RowDescription) -> Self::Output; + fn copy(&mut self, copy: CopyOutResponse) -> Self::CopyOutput; + fn complete(&mut self, _complete: CommandComplete) {} + fn notice(&mut self, _: NoticeResponse) {} + fn error(&mut self, error: ErrorResponse); +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DoneHandling { + Handled, + RedirectToParent, +} + +pub trait DataSink { + /// Sink a row of data. + fn row(&mut self, values: DataRow); + /// Handle the completion of a command. If unimplemented, will be redirected to the parent. + #[must_use] + fn done(&mut self, _result: Result) -> DoneHandling { + DoneHandling::RedirectToParent + } +} + +pub trait CopyDataSink { + /// Sink a chunk of COPY data. + fn data(&mut self, values: CopyData); + /// Handle the completion of a COPY operation. If unimplemented, will be redirected to the parent. 
+ #[must_use] + fn done(&mut self, _result: Result) -> DoneHandling { + DoneHandling::RedirectToParent + } +} + +impl QuerySink for Box +where + Q: QuerySink + 'static, +{ + type Output = Box; + type CopyOutput = Box; + fn rows(&mut self, rows: RowDescription) -> Self::Output { + Box::new(self.as_mut().rows(rows)) + } + fn copy(&mut self, copy: CopyOutResponse) -> Self::CopyOutput { + Box::new(self.as_mut().copy(copy)) + } + fn complete(&mut self, _complete: CommandComplete) { + self.as_mut().complete(_complete) + } + fn error(&mut self, error: ErrorResponse) { + self.as_mut().error(error) + } +} + +impl QuerySink for () { + type Output = (); + type CopyOutput = (); + fn rows(&mut self, _: RowDescription) {} + fn copy(&mut self, _: CopyOutResponse) {} + fn error(&mut self, _: ErrorResponse) {} +} + +impl QuerySink for (F1, F2) +where + F1: for<'a> FnMut(RowDescription<'a>) -> S, + F2: for<'a> FnMut(ErrorResponse<'a>), + S: DataSink, +{ + type Output = S; + type CopyOutput = (); + fn rows(&mut self, rows: RowDescription) -> S { + (self.0)(rows) + } + fn copy(&mut self, _: CopyOutResponse) {} + fn error(&mut self, error: ErrorResponse) { + (self.1)(error) + } +} + +impl QuerySink for (F1, F2, F3) +where + F1: for<'a> FnMut(RowDescription<'a>) -> S, + F2: for<'a> FnMut(CopyOutResponse<'a>) -> T, + F3: for<'a> FnMut(ErrorResponse<'a>), + S: DataSink, + T: CopyDataSink, +{ + type Output = S; + type CopyOutput = T; + fn rows(&mut self, rows: RowDescription) -> S { + (self.0)(rows) + } + fn copy(&mut self, copy: CopyOutResponse) -> T { + (self.1)(copy) + } + fn error(&mut self, error: ErrorResponse) { + (self.2)(error) + } +} + +impl DataSink for () { + fn row(&mut self, _: DataRow) {} +} + +impl DataSink for F +where + F: for<'a> Fn(DataRow<'a>), +{ + fn row(&mut self, values: DataRow) { + (self)(values) + } +} + +impl DataSink for Box { + fn row(&mut self, values: DataRow) { + self.as_mut().row(values) + } + fn done(&mut self, result: Result) -> DoneHandling { + 
self.as_mut().done(result) + } +} + +impl CopyDataSink for () { + fn data(&mut self, _: CopyData) {} +} + +impl CopyDataSink for F +where + F: for<'a> FnMut(CopyData<'a>), +{ + fn data(&mut self, values: CopyData) { + (self)(values) + } +} + +impl CopyDataSink for Box { + fn data(&mut self, values: CopyData) { + self.as_mut().data(values) + } + fn done(&mut self, result: Result) -> DoneHandling { + self.as_mut().done(result) + } +} + +pub(crate) struct ExecuteMessageHandler { + pub sink: Q, + pub data: Option, + pub copy: Option, +} + +impl MessageHandler for ExecuteMessageHandler { + fn name(&self) -> &'static str { + "Execute" + } + fn handle(&mut self, message: Message) -> MessageResult { + match_message!(Ok(message), Backend { + (CopyOutResponse as copy) => { + let sink = std::mem::replace(&mut self.copy, Some(self.sink.copy(copy))); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "copy sink exists" }; + } + }, + (CopyData as data) => { + if let Some(sink) = &mut self.copy { + sink.data(data); + } else { + return MessageResult::UnexpectedState { complaint: "copy sink does not exist" }; + } + }, + (CopyDone) => { + if self.copy.is_none() { + return MessageResult::UnexpectedState { complaint: "copy sink does not exist" }; + } + }, + (DataRow as row) => { + if self.data.is_none() { + self.data = Some(self.sink.rows()); + } + let Some(sink) = &mut self.data else { + unreachable!() + }; + sink.row(row) + }, + (PortalSuspended as complete) => { + if let Some(mut sink) = std::mem::take(&mut self.data) { + if sink.done(Ok(ExecuteCompletion::PortalSuspended(complete))) == DoneHandling::RedirectToParent { + self.sink.complete(ExecuteCompletion::PortalSuspended(complete)); + } + } else { + return MessageResult::UnexpectedState { complaint: "data sink does not exist" }; + } + return MessageResult::Done; + }, + (CommandComplete as complete) => { + if let Some(mut sink) = std::mem::take(&mut self.copy) { + // If COPY has started, route this to the 
COPY sink. + if sink.done(Ok(complete)) == DoneHandling::RedirectToParent { + self.sink.complete(ExecuteCompletion::CommandComplete(complete)); + } + } else if let Some(mut sink) = std::mem::take(&mut self.data) { + // If data has started, route this to the data sink. + if sink.done(Ok(ExecuteCompletion::CommandComplete(complete))) == DoneHandling::RedirectToParent { + self.sink.complete(ExecuteCompletion::CommandComplete(complete)); + } + } else { + // Otherwise, create a new data sink and route to there. + if self.sink.rows().done(Ok(ExecuteCompletion::CommandComplete(complete))) == DoneHandling::RedirectToParent { + self.sink.complete(ExecuteCompletion::CommandComplete(complete)); + } + } + return MessageResult::Done; + }, + (EmptyQueryResponse) => { + // TODO: This should be exposed to the sink + return MessageResult::Done; + }, + + (ErrorResponse as err) => { + if let Some(mut sink) = std::mem::take(&mut self.copy) { + // If COPY has started, route this to the COPY sink. + if sink.done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } else if let Some(mut sink) = std::mem::take(&mut self.data) { + // If data has started, route this to the data sink. + if sink.done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } else { + // Otherwise, create a new data sink and route to there. 
+ if self.sink.rows().done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } + + return MessageResult::SkipUntilSync; + }, + (NoticeResponse as notice) => { + self.sink.notice(notice); + }, + + _unknown => { + return MessageResult::Unknown; + } + }); + MessageResult::Continue + } +} + +pub(crate) struct QueryMessageHandler { + pub sink: Q, + pub data: Option, + pub copy: Option, +} + +impl MessageHandler for QueryMessageHandler { + fn name(&self) -> &'static str { + "Query" + } + fn handle(&mut self, message: Message) -> MessageResult { + match_message!(Ok(message), Backend { + (CopyOutResponse as copy) => { + let sink = std::mem::replace(&mut self.copy, Some(self.sink.copy(copy))); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "copy sink exists" }; + } + }, + (CopyData as data) => { + if let Some(sink) = &mut self.copy { + sink.data(data); + } else { + return MessageResult::UnexpectedState { complaint: "copy sink does not exist" }; + } + }, + (CopyDone) => { + if self.copy.is_none() { + return MessageResult::UnexpectedState { complaint: "copy sink does not exist" }; + } + }, + + (RowDescription as row) => { + let sink = std::mem::replace(&mut self.data, Some(self.sink.rows(row))); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "data sink exists" }; + } + }, + (DataRow as row) => { + if let Some(sink) = &mut self.data { + sink.row(row) + } else { + return MessageResult::UnexpectedState { complaint: "data sink does not exist" }; + } + }, + (CommandComplete as complete) => { + let sink = std::mem::take(&mut self.data); + if let Some(mut sink) = sink { + if sink.done(Ok(complete)) == DoneHandling::RedirectToParent { + self.sink.complete(complete); + } + } else { + let sink = std::mem::take(&mut self.copy); + if let Some(mut sink) = sink { + if sink.done(Ok(complete)) == DoneHandling::RedirectToParent { + self.sink.complete(complete); + } + } else { + self.sink.complete(complete); + } 
+ } + }, + + (EmptyQueryResponse) => { + // Equivalent to CommandComplete, but no data was provided + let sink = std::mem::take(&mut self.data); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "data sink exists" }; + } else { + let sink = std::mem::take(&mut self.copy); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "copy sink exists" }; + } + } + }, + + (ErrorResponse as err) => { + // Depending on the state of the sink, we direct the error to + // the appropriate handler. + if let Some(mut sink) = std::mem::take(&mut self.data) { + if sink.done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } else if let Some(mut sink) = std::mem::take(&mut self.copy) { + if sink.done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } else { + // Top level errors must complete this operation + self.sink.error(err); + } + }, + (NoticeResponse as notice) => { + self.sink.notice(notice); + }, + + (ReadyForQuery) => { + // All operations are complete at this point. + if std::mem::take(&mut self.data).is_some() || std::mem::take(&mut self.copy).is_some() { + return MessageResult::UnexpectedState { complaint: "sink exists" }; + } + return MessageResult::Done; + }, + + _unknown => { + return MessageResult::Unknown; + } + }); + MessageResult::Continue + } +} + +#[derive(Default)] +pub struct PipelineBuilder { + handlers: Vec>, + messages: Vec, +} + +impl PipelineBuilder { + fn push_flow_with_sink(mut self, flow: impl FlowWithSink) -> Self { + flow.visit_flow(|flow| self.messages.extend_from_slice(&flow.to_vec())); + self.handlers.push(flow.make_handler()); + self + } + + /// Add a bind flow to the pipeline. 
+ pub fn bind( + self, + portal: Portal, + statement: Statement, + params: &[Param], + result_format_codes: &[Format], + handler: impl SimpleFlowSink + 'static, + ) -> Self { + self.push_flow_with_sink(( + BindFlow { + portal, + statement, + params, + result_format_codes, + }, + handler, + )) + } + + /// Add a parse flow to the pipeline. + pub fn parse( + self, + name: Statement, + query: &str, + param_types: &[Oid], + handler: impl SimpleFlowSink + 'static, + ) -> Self { + self.push_flow_with_sink(( + ParseFlow { + name, + query, + param_types, + }, + handler, + )) + } + + /// Add an execute flow to the pipeline. + /// + /// Note that this may be a COPY statement. In that case, the description of the portal + /// will not show any data returned, and this will use the `CopySink` of the provided + /// sink. In addition, COPY operations do not respect the `max_rows` parameter. + pub fn execute( + self, + portal: Portal, + max_rows: MaxRows, + handler: impl ExecuteSink + 'static, + ) -> Self { + self.push_flow_with_sink((ExecuteFlow { portal, max_rows }, handler)) + } + + /// Add a close portal flow to the pipeline. + pub fn close_portal(self, name: Portal, handler: impl SimpleFlowSink + 'static) -> Self { + self.push_flow_with_sink((ClosePortalFlow { name }, handler)) + } + + /// Add a close statement flow to the pipeline. + pub fn close_statement(self, name: Statement, handler: impl SimpleFlowSink + 'static) -> Self { + self.push_flow_with_sink((CloseStatementFlow { name }, handler)) + } + + /// Add a describe portal flow to the pipeline. Note that this will describe + /// only the rows of the portal. + pub fn describe_portal(self, name: Portal, handler: impl DescribeSink + 'static) -> Self { + self.push_flow_with_sink((DescribePortalFlow { name }, handler)) + } + + /// Add a describe statement flow to the pipeline. Note that this will describe + /// both parameters and rows. 
+ pub fn describe_statement(self, name: Statement, handler: impl DescribeSink + 'static) -> Self { + self.push_flow_with_sink((DescribeStatementFlow { name }, handler)) + } + + /// Add a query flow to the pipeline. + /// + /// Note that if a query fails, the pipeline will continue executing until it + /// completes or a non-query pipeline element fails. If a previous non-query + /// element of this pipeline failed, the query will not be executed. + pub fn query(self, query: &str, handler: impl QuerySink + 'static) -> Self { + self.push_flow_with_sink((QueryFlow { query }, handler)) + } + + pub fn build(self) -> Pipeline { + Pipeline { + handlers: self.handlers, + messages: self.messages, + } + } +} + +pub struct Pipeline { + pub(crate) handlers: Vec>, + pub(crate) messages: Vec, +} + +#[derive(Default)] +/// Accumulate raw messages from a flow. Useful mainly for testing. +pub struct FlowAccumulator { + data: Vec, + messages: Vec, +} + +impl FlowAccumulator { + pub fn push(&mut self, message: impl AsRef<[u8]>) { + self.messages.push(self.data.len()); + self.data.extend_from_slice(message.as_ref()); + } + + pub fn with_messages(&self, mut f: impl FnMut(Message)) { + for &offset in &self.messages { + // First get the message header + let message = Message::new(&self.data[offset..]).unwrap(); + let len = message.mlen(); + // Then resize the message to the correct length + let message = Message::new(&self.data[offset..offset + len + 1]).unwrap(); + f(message); + } + } +} + +impl QuerySink for Rc> { + type Output = Self; + type CopyOutput = Self; + fn rows(&mut self, message: RowDescription) -> Self { + self.borrow_mut().push(message); + self.clone() + } + fn copy(&mut self, message: CopyOutResponse) -> Self { + self.borrow_mut().push(message); + self.clone() + } + fn error(&mut self, message: ErrorResponse) { + self.borrow_mut().push(message); + } + fn complete(&mut self, complete: CommandComplete) { + self.borrow_mut().push(complete); + } + fn notice(&mut self, 
message: NoticeResponse) { + self.borrow_mut().push(message); + } +} + +impl ExecuteSink for Rc> { + type Output = Self; + type CopyOutput = Self; + + fn rows(&mut self) -> Self { + self.clone() + } + fn copy(&mut self, message: CopyOutResponse) -> Self { + self.borrow_mut().push(message); + self.clone() + } + fn error(&mut self, message: ErrorResponse) { + self.borrow_mut().push(message); + } + fn complete(&mut self, complete: ExecuteCompletion) { + match complete { + ExecuteCompletion::PortalSuspended(suspended) => self.borrow_mut().push(suspended), + ExecuteCompletion::CommandComplete(complete) => self.borrow_mut().push(complete), + } + } + fn notice(&mut self, message: NoticeResponse) { + self.borrow_mut().push(message); + } +} + +impl DataSink for Rc> { + fn row(&mut self, message: DataRow) { + self.borrow_mut().push(message); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(complete) => self.borrow_mut().push(complete), + Err(err) => self.borrow_mut().push(err), + }; + DoneHandling::Handled + } +} + +impl ExecuteDataSink for Rc> { + fn row(&mut self, message: DataRow) { + self.borrow_mut().push(message); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(ExecuteCompletion::PortalSuspended(suspended)) => self.borrow_mut().push(suspended), + Ok(ExecuteCompletion::CommandComplete(complete)) => self.borrow_mut().push(complete), + Err(err) => self.borrow_mut().push(err), + }; + DoneHandling::Handled + } +} + +impl CopyDataSink for Rc> { + fn data(&mut self, message: CopyData) { + self.borrow_mut().push(message); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(complete) => self.borrow_mut().push(complete), + Err(err) => self.borrow_mut().push(err), + }; + DoneHandling::Handled + } +} + +impl SimpleFlowSink for Rc> { + fn handle(&mut self, result: Result<(), ErrorResponse>) { + match result { + Ok(()) => (), + Err(err) => self.borrow_mut().push(err), + } + } +} + +impl 
DescribeSink for Rc> { + fn params(&mut self, params: ParameterDescription) { + self.borrow_mut().push(params); + } + fn rows(&mut self, rows: RowDescription) { + self.borrow_mut().push(rows); + } + fn error(&mut self, error: ErrorResponse) { + self.borrow_mut().push(error); + } +} diff --git a/rust/pgrust/src/connection/mod.rs b/rust/pgrust/src/connection/mod.rs index e15be003092..696367a9a0f 100644 --- a/rust/pgrust/src/connection/mod.rs +++ b/rust/pgrust/src/connection/mod.rs @@ -1,19 +1,22 @@ use std::collections::HashMap; -use crate::{ - errors::{edgedb::EdbError, PgServerError}, - protocol::ParseError, -}; - +use crate::errors::{edgedb::EdbError, PgServerError}; +use db_proto::ParseError; mod conn; pub mod dsn; +mod flow; pub mod openssl; +pub(crate) mod queue; mod raw_conn; mod stream; pub mod tokio; -pub use conn::Client; +pub use conn::{Client, PGConnError}; use dsn::HostType; +pub use flow::{ + CopyDataSink, DataSink, DoneHandling, ExecuteSink, FlowAccumulator, Format, MaxRows, Oid, + Param, Pipeline, PipelineBuilder, Portal, QuerySink, Statement, +}; pub use raw_conn::connect_raw_ssl; macro_rules! __invalid_state { @@ -91,7 +94,7 @@ pub struct Credentials { pub server_settings: HashMap, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, derive_more::From)] /// The resolved target of a connection attempt. pub enum ResolvedTarget { SocketAddr(std::net::SocketAddr), @@ -100,6 +103,17 @@ pub enum ResolvedTarget { } impl ResolvedTarget { + #[cfg(test)] + pub fn from_captive_server_listen_address(address: captive_postgres::ListenAddress) -> Self { + match address { + captive_postgres::ListenAddress::Tcp(addr) => Self::SocketAddr(addr), + #[cfg(unix)] + captive_postgres::ListenAddress::Unix(path) => { + Self::UnixSocketAddr(std::os::unix::net::SocketAddr::from_pathname(path).unwrap()) + } + } + } + /// Resolves the target addresses for a given host. 
pub fn to_addrs_sync(host: &dsn::Host) -> Result, std::io::Error> { use std::net::{SocketAddr, ToSocketAddrs}; diff --git a/rust/pgrust/src/connection/queue.rs b/rust/pgrust/src/connection/queue.rs new file mode 100644 index 00000000000..211beed5809 --- /dev/null +++ b/rust/pgrust/src/connection/queue.rs @@ -0,0 +1,166 @@ +use std::future::Future; +use std::ops::DerefMut; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// A queue of futures that can be polled in order. +/// +/// Only one future will be active at a time. If no futures are active, the +/// waker will be triggered when the next future is submitted to the queue. +pub struct FutureQueue { + queue: tokio::sync::mpsc::UnboundedReceiver>>>, + sender: tokio::sync::mpsc::UnboundedSender>>>, + current: Option>>>, +} + +#[cfg(test)] +#[derive(Clone)] +pub struct FutureQueueSender { + sender: tokio::sync::mpsc::UnboundedSender>>>, +} + +#[cfg(test)] +impl FutureQueueSender { + pub fn submit(&self, future: impl Future + 'static) { + // This will never fail because the receiver still exists + self.sender.send(Box::pin(future)).unwrap(); + } +} + +impl FutureQueue { + #[cfg(test)] + pub fn sender(&self) -> FutureQueueSender { + FutureQueueSender { + sender: self.sender.clone(), + } + } + + pub fn submit(&self, future: impl Future + 'static) { + // This will never fail because we hold both ends of the channel. + self.sender.send(Box::pin(future)).unwrap(); + } + + /// Poll the current future, or if there is no current future, poll for the next item + /// from the queue (and then poll that future). + pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + if let Some(future) = self.current.as_mut() { + match future.as_mut().poll(cx) { + Poll::Ready(output) => { + self.current = None; + return Poll::Ready(Some(output)); + } + Poll::Pending => return Poll::Pending, + } + } + + // If there is no current future, try to receive the next one from the queue. 
+ let next = match self.queue.poll_recv(cx) { + Poll::Ready(Some(next)) => next, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + }; + + // Note that we loop around to poll this future until we get a Pending + // result. + self.current = Some(next); + } + } +} + +impl Default for FutureQueue { + fn default() -> Self { + let (sender, receiver) = tokio::sync::mpsc::unbounded_channel(); + Self { + queue: receiver, + sender, + current: None, + } + } +} + +impl futures::Stream for FutureQueue { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // We're Unpin + let this = self.deref_mut(); + this.poll_next_unpin(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::StreamExt; + use tokio::{ + task::LocalSet, + time::{sleep, Duration}, + }; + + #[tokio::test] + async fn test_basic_queue() { + LocalSet::new() + .run_until(async { + let mut queue = FutureQueue::default(); + let sender = queue.sender(); + + // Spawn a task that sends some futures + tokio::task::spawn_local(async move { + sleep(Duration::from_millis(10)).await; + sender.submit(async { 1 }); + sleep(Duration::from_millis(10)).await; + sender.submit(async { 2 }); + sleep(Duration::from_millis(10)).await; + sender.submit(async { 3 }); + }); + + // Collect results + let mut results = Vec::new(); + while let Some(value) = queue.next().await { + results.push(value); + if results.len() == 3 { + break; + } + } + + assert_eq!(results, vec![1, 2, 3]); + }) + .await; + } + + #[tokio::test] + async fn test_delayed_futures() { + LocalSet::new() + .run_until(async { + let mut queue = FutureQueue::default(); + let sender = queue.sender(); + + // Spawn task with delayed futures + tokio::task::spawn_local(async move { + sleep(Duration::from_millis(10)).await; + sender.submit(async { + sleep(Duration::from_millis(50)).await; + 1 + }); + sleep(Duration::from_millis(10)).await; + sender.submit(async { + 
sleep(Duration::from_millis(10)).await; + 2 + }); + }); + + // Even though second future completes first, results should be in order of sending + let mut results = Vec::new(); + while let Some(value) = queue.next().await { + results.push(value); + if results.len() == 2 { + break; + } + } + + assert_eq!(results, vec![1, 2]); + }) + .await; + } +} diff --git a/rust/pgrust/src/connection/raw_conn.rs b/rust/pgrust/src/connection/raw_conn.rs index 24d16885e23..07eaeec5e7c 100644 --- a/rust/pgrust/src/connection/raw_conn.rs +++ b/rust/pgrust/src/connection/raw_conn.rs @@ -1,5 +1,5 @@ use super::{ - stream::{Stream, StreamWithUpgrade, UpgradableStream}, + stream::{Stream, StreamWithUpgrade, UpgradableStream, UpgradableStreamChoice}, ConnectionError, Credentials, }; use crate::handshake::{ @@ -10,7 +10,8 @@ use crate::handshake::{ ConnectionSslRequirement, }; use crate::protocol::postgres::{FrontendBuilder, InitialBuilder}; -use crate::protocol::{postgres::data::SSLResponse, postgres::meta, StructBuffer}; +use crate::protocol::{postgres::data::SSLResponse, postgres::meta}; +use db_proto::StructBuffer; use gel_auth::AuthType; use std::collections::HashMap; use std::pin::Pin; @@ -147,10 +148,20 @@ pub struct RawClient where (B, C): StreamWithUpgrade, { - stream: UpgradableStream, + stream: UpgradableStreamChoice, params: ConnectionParams, } +impl RawClient { + /// Create a new raw client from a stream. The stream must be fully authenticated and ready. 
+ pub fn new(stream: B, params: ConnectionParams) -> Self { + Self { + stream: UpgradableStreamChoice::Base(stream), + params, + } + } +} + impl RawClient where (B, C): StreamWithUpgrade, @@ -169,7 +180,10 @@ where cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - Pin::new(&mut self.get_mut().stream).poll_read(cx, buf) + match &mut self.get_mut().stream { + UpgradableStreamChoice::Base(base) => Pin::new(base).poll_read(cx, buf), + UpgradableStreamChoice::Upgrade(upgraded) => Pin::new(upgraded).poll_read(cx, buf), + } } } @@ -182,18 +196,47 @@ where cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - Pin::new(&mut self.get_mut().stream).poll_write(cx, buf) + match &mut self.get_mut().stream { + UpgradableStreamChoice::Base(base) => Pin::new(base).poll_write(cx, buf), + UpgradableStreamChoice::Upgrade(upgraded) => Pin::new(upgraded).poll_write(cx, buf), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + match &mut self.get_mut().stream { + UpgradableStreamChoice::Base(base) => Pin::new(base).poll_write_vectored(cx, bufs), + UpgradableStreamChoice::Upgrade(upgraded) => { + Pin::new(upgraded).poll_write_vectored(cx, bufs) + } + } + } + + fn is_write_vectored(&self) -> bool { + match &self.stream { + UpgradableStreamChoice::Base(base) => base.is_write_vectored(), + UpgradableStreamChoice::Upgrade(upgraded) => upgraded.is_write_vectored(), + } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().stream).poll_flush(cx) + match &mut self.get_mut().stream { + UpgradableStreamChoice::Base(base) => Pin::new(base).poll_flush(cx), + UpgradableStreamChoice::Upgrade(upgraded) => Pin::new(upgraded).poll_flush(cx), + } } fn poll_shutdown( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - Pin::new(&mut self.get_mut().stream).poll_shutdown(cx) + match &mut self.get_mut().stream { + UpgradableStreamChoice::Base(base) => 
Pin::new(base).poll_shutdown(cx), + UpgradableStreamChoice::Upgrade(upgraded) => Pin::new(upgraded).poll_shutdown(cx), + } } } @@ -245,6 +288,8 @@ where .drive_bytes(&mut state, &buffer[..n], &mut struct_buffer, &mut stream) .await?; } + + let stream = stream.into_choice().unwrap(); Ok(RawClient { stream, params: update.params, diff --git a/rust/pgrust/src/connection/stream.rs b/rust/pgrust/src/connection/stream.rs index 6606a083c93..07cdc91273c 100644 --- a/rust/pgrust/src/connection/stream.rs +++ b/rust/pgrust/src/connection/stream.rs @@ -33,6 +33,7 @@ impl StreamWithUpgrade for (S, ()) { } } +#[derive(derive_more::Debug)] pub struct UpgradableStream where (B, C): StreamWithUpgrade, @@ -79,6 +80,19 @@ where )), } } + + /// Convert the inner stream into a choice between the base and the upgraded stream. + /// + /// If the inner stream is in the process of upgrading, return an error containing `self`. + pub fn into_choice(self) -> Result, Self> { + match self.inner { + UpgradableStreamInner::Base(base, _) => Ok(UpgradableStreamChoice::Base(base)), + UpgradableStreamInner::Upgraded(upgraded) => { + Ok(UpgradableStreamChoice::Upgrade(upgraded)) + } + UpgradableStreamInner::Upgrading => Err(self), + } + } } impl tokio::io::AsyncRead for UpgradableStream @@ -185,11 +199,26 @@ where } } +#[derive(derive_more::Debug)] enum UpgradableStreamInner where (B, C): StreamWithUpgrade, { + #[debug("Base(..)")] Base(B, C), + #[debug("Upgraded(..)")] Upgraded(<(B, C) as StreamWithUpgrade>::Upgrade), + #[debug("Upgrading")] Upgrading, } + +#[derive(derive_more::Debug)] +pub enum UpgradableStreamChoice +where + (B, C): StreamWithUpgrade, +{ + #[debug("Base(..)")] + Base(B), + #[debug("Upgrade(..)")] + Upgrade(<(B, C) as StreamWithUpgrade>::Upgrade), +} diff --git a/rust/pgrust/src/errors/mod.rs b/rust/pgrust/src/errors/mod.rs index 39e7e538e28..a414b2ee112 100644 --- a/rust/pgrust/src/errors/mod.rs +++ b/rust/pgrust/src/errors/mod.rs @@ -135,7 +135,7 @@ macro_rules! 
pg_error { )* paste!( - /// Postgres error codes. See https://www.postgresql.org/docs/current/errcodes-appendix.html. + /// Postgres error codes. See <https://www.postgresql.org/docs/current/errcodes-appendix.html>. #[derive(Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] pub enum PgError { $( diff --git a/rust/pgrust/src/handshake/client_state_machine.rs b/rust/pgrust/src/handshake/client_state_machine.rs index 7723f22b3d9..e0fa6583ea0 100644 --- a/rust/pgrust/src/handshake/client_state_machine.rs +++ b/rust/pgrust/src/handshake/client_state_machine.rs @@ -3,7 +3,6 @@ use crate::{ connection::{invalid_state, ConnectionError, Credentials, SslError}, errors::PgServerError, protocol::{ - match_message, postgres::data::{ AuthenticationCleartextPassword, AuthenticationMD5Password, AuthenticationMessage, AuthenticationOk, AuthenticationSASL, AuthenticationSASLContinue, @@ -11,10 +10,10 @@ use crate::{ ReadyForQuery, SSLResponse, }, postgres::{builder, FrontendBuilder, InitialBuilder}, - ParseError, }, }; use base64::Engine; +use db_proto::{match_message, ParseError}; use gel_auth::{ scram::{generate_salted_password, ClientEnvironment, ClientTransaction, Sha256Out}, AuthType, diff --git a/rust/pgrust/src/handshake/edgedb_server.rs b/rust/pgrust/src/handshake/edgedb_server.rs index faf3669a2e1..c5cf6915ab7 100644 --- a/rust/pgrust/src/handshake/edgedb_server.rs +++ b/rust/pgrust/src/handshake/edgedb_server.rs @@ -1,11 +1,9 @@ use crate::{ connection::ConnectionError, errors::edgedb::EdbError, - protocol::{ - edgedb::{data::*, *}, - match_message, ParseError, StructBuffer, - }, + protocol::edgedb::{data::*, *}, }; +use db_proto::{match_message, ParseError, StructBuffer}; use gel_auth::{ handshake::{ServerAuth, ServerAuthDrive, ServerAuthError, ServerAuthResponse}, AuthType, CredentialData, diff --git a/rust/pgrust/src/handshake/server_state_machine.rs b/rust/pgrust/src/handshake/server_state_machine.rs index 8c62c00b610..f7aa30f316f 100644 --- a/rust/pgrust/src/handshake/server_state_machine.rs +++ 
b/rust/pgrust/src/handshake/server_state_machine.rs @@ -5,12 +5,9 @@ use crate::{ PgError, PgErrorConnectionException, PgErrorFeatureNotSupported, PgErrorInvalidAuthorizationSpecification, PgServerError, PgServerErrorField, }, - protocol::{ - match_message, - postgres::{data::*, *}, - ParseError, StructBuffer, - }, + protocol::postgres::{data::*, *}, }; +use db_proto::{match_message, ParseError, StructBuffer}; use gel_auth::{ handshake::{ServerAuth, ServerAuthDrive, ServerAuthError, ServerAuthResponse}, AuthType, CredentialData, diff --git a/rust/pgrust/src/protocol/datatypes.rs b/rust/pgrust/src/protocol/datatypes.rs deleted file mode 100644 index 99433de3ce4..00000000000 --- a/rust/pgrust/src/protocol/datatypes.rs +++ /dev/null @@ -1,779 +0,0 @@ -use std::{marker::PhantomData, str::Utf8Error}; - -use uuid::Uuid; - -use super::{ - arrays::{array_access, Array, ArrayMeta}, - field_access, - writer::BufWriter, - Enliven, FieldAccess, Meta, ParseError, -}; - -pub mod meta { - pub use super::EncodedMeta as Encoded; - pub use super::LStringMeta as LString; - pub use super::LengthMeta as Length; - pub use super::RestMeta as Rest; - pub use super::UuidMeta as Uuid; - pub use super::ZTStringMeta as ZTString; -} - -/// Represents the remainder of data in a message. 
-#[derive(Debug, PartialEq, Eq)] -pub struct Rest<'a> { - buf: &'a [u8], -} - -field_access!(RestMeta); - -pub struct RestMeta {} -impl Meta for RestMeta { - fn name(&self) -> &'static str { - "Rest" - } -} -impl Enliven for RestMeta { - type WithLifetime<'a> = Rest<'a>; - type ForMeasure<'a> = &'a [u8]; - type ForBuilder<'a> = &'a [u8]; -} - -impl<'a> Rest<'a> {} - -impl<'a> AsRef<[u8]> for Rest<'a> { - fn as_ref(&self) -> &[u8] { - self.buf - } -} - -impl<'a> std::ops::Deref for Rest<'a> { - type Target = [u8]; - fn deref(&self) -> &Self::Target { - self.buf - } -} - -impl PartialEq<[u8]> for Rest<'_> { - fn eq(&self, other: &[u8]) -> bool { - self.buf == other - } -} - -impl PartialEq<&[u8; N]> for Rest<'_> { - fn eq(&self, other: &&[u8; N]) -> bool { - self.buf == *other - } -} - -impl PartialEq<&[u8]> for Rest<'_> { - fn eq(&self, other: &&[u8]) -> bool { - self.buf == *other - } -} - -impl FieldAccess { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - &RestMeta {} - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - Ok(buf.len()) - } - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result, ParseError> { - Ok(Rest { buf }) - } - #[inline(always)] - pub const fn measure(buf: &[u8]) -> usize { - buf.len() - } - #[inline(always)] - pub fn copy_to_buf(buf: &mut BufWriter, value: &[u8]) { - buf.write(value) - } -} - -/// A zero-terminated string. 
-#[allow(unused)] -pub struct ZTString<'a> { - buf: &'a [u8], -} - -field_access!(ZTStringMeta); -array_access!(ZTStringMeta); - -pub struct ZTStringMeta {} -impl Meta for ZTStringMeta { - fn name(&self) -> &'static str { - "ZTString" - } -} - -impl Enliven for ZTStringMeta { - type WithLifetime<'a> = ZTString<'a>; - type ForMeasure<'a> = &'a str; - type ForBuilder<'a> = &'a str; -} - -impl std::fmt::Debug for ZTString<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - String::from_utf8_lossy(self.buf).fmt(f) - } -} - -impl<'a> ZTString<'a> { - pub fn to_owned(&self) -> Result { - std::str::from_utf8(self.buf).map(|s| s.to_owned()) - } - - pub fn to_str(&self) -> Result<&str, std::str::Utf8Error> { - std::str::from_utf8(self.buf) - } - - pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { - String::from_utf8_lossy(self.buf) - } - - pub fn to_bytes(&self) -> &[u8] { - self.buf - } -} - -impl PartialEq for ZTString<'_> { - fn eq(&self, other: &Self) -> bool { - self.buf == other.buf - } -} -impl Eq for ZTString<'_> {} - -impl PartialEq for ZTString<'_> { - fn eq(&self, other: &str) -> bool { - self.buf == other.as_bytes() - } -} - -impl PartialEq<&str> for ZTString<'_> { - fn eq(&self, other: &&str) -> bool { - self.buf == other.as_bytes() - } -} - -impl<'a> TryInto<&'a str> for ZTString<'a> { - type Error = Utf8Error; - fn try_into(self) -> Result<&'a str, Self::Error> { - std::str::from_utf8(self.buf) - } -} - -impl FieldAccess { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - &ZTStringMeta {} - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - let mut i = 0; - loop { - if i >= buf.len() { - return Err(ParseError::TooShort); - } - if buf[i] == 0 { - return Ok(i + 1); - } - i += 1; - } - } - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result, ParseError> { - let buf = buf.split_at(buf.len() - 1).0; - Ok(ZTString { buf }) - } - #[inline(always)] - pub const fn 
measure(buf: &str) -> usize { - buf.len() + 1 - } - #[inline(always)] - pub fn copy_to_buf(buf: &mut BufWriter, value: &str) { - buf.write(value.as_bytes()); - buf.write_u8(0); - } - #[inline(always)] - pub fn copy_to_buf_ref(buf: &mut BufWriter, value: &str) { - buf.write(value.as_bytes()); - buf.write_u8(0); - } -} - -/// A length-prefixed string. -#[allow(unused)] -pub struct LString<'a> { - buf: &'a [u8], -} - -field_access!(LStringMeta); -array_access!(LStringMeta); - -pub struct LStringMeta {} -impl Meta for LStringMeta { - fn name(&self) -> &'static str { - "LString" - } -} - -impl Enliven for LStringMeta { - type WithLifetime<'a> = LString<'a>; - type ForMeasure<'a> = &'a str; - type ForBuilder<'a> = &'a str; -} - -impl std::fmt::Debug for LString<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - String::from_utf8_lossy(self.buf).fmt(f) - } -} - -impl<'a> LString<'a> { - pub fn to_owned(&self) -> Result { - std::str::from_utf8(self.buf).map(|s| s.to_owned()) - } - - pub fn to_str(&self) -> Result<&str, std::str::Utf8Error> { - std::str::from_utf8(self.buf) - } - - pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { - String::from_utf8_lossy(self.buf) - } - - pub fn to_bytes(&self) -> &[u8] { - self.buf - } -} - -impl PartialEq for LString<'_> { - fn eq(&self, other: &Self) -> bool { - self.buf == other.buf - } -} -impl Eq for LString<'_> {} - -impl PartialEq for LString<'_> { - fn eq(&self, other: &str) -> bool { - self.buf == other.as_bytes() - } -} - -impl PartialEq<&str> for LString<'_> { - fn eq(&self, other: &&str) -> bool { - self.buf == other.as_bytes() - } -} - -impl<'a> TryInto<&'a str> for LString<'a> { - type Error = Utf8Error; - fn try_into(self) -> Result<&'a str, Self::Error> { - std::str::from_utf8(self.buf) - } -} - -impl FieldAccess { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - &LStringMeta {} - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - if 
buf.len() < 4 { - return Err(ParseError::TooShort); - } - let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; - Ok(4 + len) - } - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result, ParseError> { - if buf.len() < 4 { - return Err(ParseError::TooShort); - } - let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; - if buf.len() < 4 + len { - return Err(ParseError::TooShort); - } - Ok(LString { - buf: buf.split_at(4).1, - }) - } - #[inline(always)] - pub const fn measure(buf: &str) -> usize { - 4 + buf.len() - } - #[inline(always)] - pub fn copy_to_buf(buf: &mut BufWriter, value: &str) { - let len = value.len() as u32; - buf.write(&len.to_be_bytes()); - buf.write(value.as_bytes()); - } - #[inline(always)] - pub fn copy_to_buf_ref(buf: &mut BufWriter, value: &str) { - let len = value.len() as u32; - buf.write(&len.to_be_bytes()); - buf.write(value.as_bytes()); - } -} - -field_access!(UuidMeta); -array_access!(UuidMeta); - -pub struct UuidMeta {} -impl Meta for UuidMeta { - fn name(&self) -> &'static str { - "Uuid" - } -} - -impl Enliven for UuidMeta { - type WithLifetime<'a> = Uuid; - type ForMeasure<'a> = Uuid; - type ForBuilder<'a> = Uuid; -} - -impl FieldAccess { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - &UuidMeta {} - } - - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - if buf.len() < 16 { - Err(ParseError::TooShort) - } else { - Ok(16) - } - } - - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result { - if let Some(bytes) = buf.first_chunk() { - Ok(Uuid::from_u128(::from_be_bytes(*bytes))) - } else { - Err(ParseError::TooShort) - } - } - - #[inline(always)] - pub const fn measure(_value: &Uuid) -> usize { - 16 - } - - #[inline(always)] - pub fn copy_to_buf(buf: &mut BufWriter, value: Uuid) { - buf.write(value.as_bytes().as_slice()); - } - - #[inline(always)] - pub fn copy_to_buf_ref(buf: &mut BufWriter, value: &Uuid) { - 
buf.write(value.as_bytes().as_slice()); - } -} - -#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] -/// An encoded row value. -pub enum Encoded<'a> { - #[default] - Null, - Value(&'a [u8]), -} - -impl<'a> AsRef> for Encoded<'a> { - fn as_ref(&self) -> &Encoded<'a> { - self - } -} - -field_access!(EncodedMeta); -array_access!(EncodedMeta); - -pub struct EncodedMeta {} -impl Meta for EncodedMeta { - fn name(&self) -> &'static str { - "Encoded" - } -} - -impl Enliven for EncodedMeta { - type WithLifetime<'a> = Encoded<'a>; - type ForMeasure<'a> = Encoded<'a>; - type ForBuilder<'a> = Encoded<'a>; -} - -impl<'a> Encoded<'a> {} - -impl PartialEq for Encoded<'_> { - fn eq(&self, other: &str) -> bool { - self == &Encoded::Value(other.as_bytes()) - } -} - -impl PartialEq<&str> for Encoded<'_> { - fn eq(&self, other: &&str) -> bool { - self == &Encoded::Value(other.as_bytes()) - } -} - -impl PartialEq<[u8]> for Encoded<'_> { - fn eq(&self, other: &[u8]) -> bool { - self == &Encoded::Value(other) - } -} - -impl PartialEq<&[u8]> for Encoded<'_> { - fn eq(&self, other: &&[u8]) -> bool { - self == &Encoded::Value(other) - } -} - -impl FieldAccess { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - &EncodedMeta {} - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - const N: usize = std::mem::size_of::(); - if let Some(len) = buf.first_chunk::() { - let len = i32::from_be_bytes(*len); - if len < 0 { - Err(ParseError::InvalidData) - } else if len == -1 { - Ok(N) - } else if buf.len() < len as usize + N { - Err(ParseError::TooShort) - } else { - Ok(len as usize + N) - } - } else { - Err(ParseError::TooShort) - } - } - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result, ParseError> { - const N: usize = std::mem::size_of::(); - if let Some((len, array)) = buf.split_first_chunk::() { - let len = i32::from_be_bytes(*len); - if len < 0 { - Err(ParseError::InvalidData) - } else if len == -1 { - Ok(Encoded::Null) - } else if 
array.len() < len as _ { - Err(ParseError::TooShort) - } else { - Ok(Encoded::Value(array)) - } - } else { - Err(ParseError::TooShort) - } - } - #[inline(always)] - pub const fn measure(value: &Encoded) -> usize { - match value { - Encoded::Null => std::mem::size_of::(), - Encoded::Value(value) => value.len() + std::mem::size_of::(), - } - } - #[inline(always)] - pub fn copy_to_buf(buf: &mut BufWriter, value: Encoded) { - Self::copy_to_buf_ref(buf, &value) - } - #[inline(always)] - pub fn copy_to_buf_ref(buf: &mut BufWriter, value: &Encoded) { - match value { - Encoded::Null => buf.write(&[0xff, 0xff, 0xff, 0xff]), - Encoded::Value(value) => { - let len: i32 = value.len() as _; - buf.write(&len.to_be_bytes()); - buf.write(value); - } - } - } -} - -// We alias usize here. Note that if this causes trouble in the future we can -// probably work around this by adding a new "const value" function to -// FieldAccess. For now it works! -pub struct LengthMeta(#[allow(unused)] i32); -impl Enliven for LengthMeta { - type WithLifetime<'a> = usize; - type ForMeasure<'a> = usize; - type ForBuilder<'a> = usize; -} -impl Meta for LengthMeta { - fn name(&self) -> &'static str { - "len" - } -} - -impl FieldAccess { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - &LengthMeta(0) - } - #[inline(always)] - pub const fn constant(value: usize) -> LengthMeta { - LengthMeta(value as i32) - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - match FieldAccess::::extract(buf) { - Ok(n) if n >= 0 => Ok(std::mem::size_of::()), - Ok(_) => Err(ParseError::InvalidData), - Err(e) => Err(e), - } - } - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result { - match FieldAccess::::extract(buf) { - Ok(n) if n >= 0 => Ok(n as _), - Ok(_) => Err(ParseError::InvalidData), - Err(e) => Err(e), - } - } - #[inline(always)] - pub fn copy_to_buf(buf: &mut BufWriter, value: usize) { - FieldAccess::::copy_to_buf(buf, value as i32) - } - #[inline(always)] - 
pub fn copy_to_buf_rewind(buf: &mut BufWriter, rewind: usize, value: usize) { - FieldAccess::::copy_to_buf_rewind(buf, rewind, value as i32) - } -} - -macro_rules! basic_types { - ($($ty:ty)*) => { - $( - field_access!{$ty} - - impl Enliven for $ty { - type WithLifetime<'a> = $ty; - type ForMeasure<'a> = $ty; - type ForBuilder<'a> = $ty; - } - - impl Enliven for [$ty; S] { - type WithLifetime<'a> = [$ty; S]; - type ForMeasure<'a> = [$ty; S]; - type ForBuilder<'a> = [$ty; S]; - } - - #[allow(unused)] - impl FieldAccess<$ty> { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - struct Meta {} - impl $crate::protocol::Meta for Meta { - fn name(&self) -> &'static str { - stringify!($ty) - } - } - &Meta{} - } - #[inline(always)] - pub const fn constant(value: usize) -> $ty { - value as _ - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - let size = std::mem::size_of::<$ty>(); - if size > buf.len() { - Err($crate::protocol::ParseError::TooShort) - } else { - Ok(size) - } - } - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result<$ty, $crate::protocol::ParseError> { - if let Some(bytes) = buf.first_chunk() { - Ok(<$ty>::from_be_bytes(*bytes)) - } else { - Err($crate::protocol::ParseError::TooShort) - } - } - #[inline(always)] - pub fn copy_to_buf(buf: &mut BufWriter, value: $ty) { - buf.write(&<$ty>::to_be_bytes(value)); - } - #[inline(always)] - pub fn copy_to_buf_rewind(buf: &mut BufWriter, rewind: usize, value: $ty) { - buf.write_rewind(rewind, &<$ty>::to_be_bytes(value)); - } - } - - #[allow(unused)] - impl FieldAccess<[$ty; S]> { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - struct Meta {} - impl $crate::protocol::Meta for Meta { - fn name(&self) -> &'static str { - // TODO: can we extract this constant? 
- concat!('[', stringify!($ty), "; ", "S") - } - } - &Meta{} - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - let size = std::mem::size_of::<$ty>() * S; - if size > buf.len() { - Err($crate::protocol::ParseError::TooShort) - } else { - Ok(size) - } - } - #[inline(always)] - pub const fn extract(mut buf: &[u8]) -> Result<[$ty; S], $crate::protocol::ParseError> { - let mut out: [$ty; S] = [0; S]; - let mut i = 0; - loop { - if i == S { - break; - } - (out[i], buf) = if let Some((bytes, rest)) = buf.split_first_chunk() { - (<$ty>::from_be_bytes(*bytes), rest) - } else { - return Err($crate::protocol::ParseError::TooShort) - }; - i += 1; - } - Ok(out) - } - #[inline(always)] - pub fn copy_to_buf(mut buf: &mut BufWriter, value: [$ty; S]) { - if !buf.test(std::mem::size_of::<$ty>() * S) { - return; - } - for n in value { - buf.write(&<$ty>::to_be_bytes(n)); - } - } - } - - impl $crate::protocol::FixedSize for $ty { - const SIZE: usize = std::mem::size_of::<$ty>(); - #[inline(always)] - fn extract_infallible(buf: &[u8]) -> $ty { - if let Some(buf) = buf.first_chunk() { - <$ty>::from_be_bytes(*buf) - } else { - panic!() - } - } - } - impl $crate::protocol::FixedSize for [$ty; S] { - const SIZE: usize = std::mem::size_of::<$ty>() * S; - #[inline(always)] - fn extract_infallible(mut buf: &[u8]) -> [$ty; S] { - let mut out: [$ty; S] = [0; S]; - let mut i = 0; - loop { - if i == S { - break; - } - (out[i], buf) = if let Some((bytes, rest)) = buf.split_first_chunk() { - (<$ty>::from_be_bytes(*bytes), rest) - } else { - panic!() - }; - i += 1; - } - out - } - } - - basic_types!(: array<$ty> u8 i16 i32 u32 u64); - )* - }; - - (: array<$ty:ty> $($len:ty)*) => { - $( - #[allow(unused)] - impl FieldAccess> { - pub const fn meta() -> &'static dyn Meta { - &ArrayMeta::<$len, $ty> { _phantom: PhantomData } - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - const N: usize = std::mem::size_of::<$ty>(); - const L: usize = 
std::mem::size_of::<$len>(); - if let Some(len) = buf.first_chunk::() { - let len_value = <$len>::from_be_bytes(*len); - #[allow(unused_comparisons)] - if len_value < 0 { - return Err($crate::protocol::ParseError::InvalidData); - } - let mut byte_len = len_value as usize; - byte_len = match byte_len.checked_mul(N) { - Some(l) => l, - None => return Err($crate::protocol::ParseError::TooShort), - }; - byte_len = match byte_len.checked_add(L) { - Some(l) => l, - None => return Err($crate::protocol::ParseError::TooShort), - }; - if buf.len() < byte_len { - Err($crate::protocol::ParseError::TooShort) - } else { - Ok(byte_len) - } - } else { - Err($crate::protocol::ParseError::TooShort) - } - } - #[inline(always)] - pub const fn extract(mut buf: &[u8]) -> Result, $crate::protocol::ParseError> { - const N: usize = std::mem::size_of::<$ty>(); - const L: usize = std::mem::size_of::<$len>(); - if let Some((len, array)) = buf.split_first_chunk::() { - let len_value = <$len>::from_be_bytes(*len); - #[allow(unused_comparisons)] - if len_value < 0 { - return Err($crate::protocol::ParseError::InvalidData); - } - let mut byte_len = len_value as usize; - byte_len = match byte_len.checked_mul(N) { - Some(l) => l, - None => return Err($crate::protocol::ParseError::TooShort), - }; - byte_len = match byte_len.checked_add(L) { - Some(l) => l, - None => return Err($crate::protocol::ParseError::TooShort), - }; - if buf.len() < byte_len { - Err($crate::protocol::ParseError::TooShort) - } else { - Ok(Array::new(array, <$len>::from_be_bytes(*len) as u32)) - } - } else { - Err($crate::protocol::ParseError::TooShort) - } - } - #[inline(always)] - pub const fn measure(buffer: &[$ty]) -> usize { - buffer.len() * std::mem::size_of::<$ty>() + std::mem::size_of::<$len>() - } - #[inline(always)] - pub fn copy_to_buf(mut buf: &mut BufWriter, value: &[$ty]) { - let size: usize = std::mem::size_of::<$ty>() * value.len() + std::mem::size_of::<$len>(); - if !buf.test(size) { - return; - } - 
buf.write(&<$len>::to_be_bytes(value.len() as _)); - for n in value { - buf.write(&<$ty>::to_be_bytes(*n)); - } - } - } - )* - } -} -basic_types!(u8 i16 i32 u32 u64); diff --git a/rust/pgrust/src/protocol/definition.rs b/rust/pgrust/src/protocol/definition.rs deleted file mode 100644 index bef36d5aa65..00000000000 --- a/rust/pgrust/src/protocol/definition.rs +++ /dev/null @@ -1,740 +0,0 @@ -use super::gen::protocol; -use super::message_group::message_group; -use crate::protocol::meta::*; - -message_group!( - /// The `Backend` message group contains messages sent from the backend to the frontend. - Backend: Message = [ - AuthenticationOk, - AuthenticationKerberosV5, - AuthenticationCleartextPassword, - AuthenticationMD5Password, - AuthenticationGSS, - AuthenticationGSSContinue, - AuthenticationSSPI, - AuthenticationSASL, - AuthenticationSASLContinue, - AuthenticationSASLFinal, - BackendKeyData, - BindComplete, - CloseComplete, - CommandComplete, - CopyData, - CopyDone, - CopyInResponse, - CopyOutResponse, - CopyBothResponse, - DataRow, - EmptyQueryResponse, - ErrorResponse, - FunctionCallResponse, - NegotiateProtocolVersion, - NoData, - NoticeResponse, - NotificationResponse, - ParameterDescription, - ParameterStatus, - ParseComplete, - PortalSuspended, - ReadyForQuery, - RowDescription - ] -); - -message_group!( - /// The `Frontend` message group contains messages sent from the frontend to the backend. - Frontend: Message = [ - Bind, - Close, - CopyData, - CopyDone, - CopyFail, - Describe, - Execute, - Flush, - FunctionCall, - GSSResponse, - Parse, - PasswordMessage, - Query, - SASLInitialResponse, - SASLResponse, - Sync, - Terminate - ] -); - -message_group!( - /// The `Initial` message group contains messages that are sent before the - /// normal message flow. - Initial: InitialMessage = [ - CancelRequest, - GSSENCRequest, - SSLRequest, - StartupMessage - ] -); - -protocol!( - -/// A generic base for all Postgres mtype/mlen-style messages. 
-struct Message { - /// Identifies the message. - mtype: u8, - /// Length of message contents in bytes, including self. - mlen: len, - /// Message contents. - data: Rest, -} - -/// A generic base for all initial Postgres messages. -struct InitialMessage { - /// Length of message contents in bytes, including self. - mlen: len, - /// The identifier for this initial message. - protocol_version: i32, - /// Message contents. - data: Rest -} - -/// The `AuthenticationMessage` struct is a base for all Postgres authentication messages. -struct AuthenticationMessage: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len, - /// Specifies that the authentication was successful. - status: i32, -} - -/// The `AuthenticationOk` struct represents a message indicating successful authentication. -struct AuthenticationOk: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len = 8, - /// Specifies that the authentication was successful. - status: i32 = 0, -} - -/// The `AuthenticationKerberosV5` struct represents a message indicating that Kerberos V5 authentication is required. -struct AuthenticationKerberosV5: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len = 8, - /// Specifies that Kerberos V5 authentication is required. - status: i32 = 2, -} - -/// The `AuthenticationCleartextPassword` struct represents a message indicating that a cleartext password is required for authentication. -struct AuthenticationCleartextPassword: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len = 8, - /// Specifies that a clear-text password is required. 
- status: i32 = 3, -} - -/// The `AuthenticationMD5Password` struct represents a message indicating that an MD5-encrypted password is required for authentication. -struct AuthenticationMD5Password: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len = 12, - /// Specifies that an MD5-encrypted password is required. - status: i32 = 5, - /// The salt to use when encrypting the password. - salt: [u8; 4], -} - -/// The `AuthenticationSCMCredential` struct represents a message indicating that an SCM credential is required for authentication. -struct AuthenticationSCMCredential: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len = 6, - /// Any data byte, which is ignored. - byte: u8 = 0, -} - -/// The `AuthenticationGSS` struct represents a message indicating that GSSAPI authentication is required. -struct AuthenticationGSS: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len = 8, - /// Specifies that GSSAPI authentication is required. - status: i32 = 7, -} - -/// The `AuthenticationGSSContinue` struct represents a message indicating the continuation of GSSAPI authentication. -struct AuthenticationGSSContinue: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len, - /// Specifies that this message contains GSSAPI or SSPI data. - status: i32 = 8, - /// GSSAPI or SSPI authentication data. - data: Rest, -} - -/// The `AuthenticationSSPI` struct represents a message indicating that SSPI authentication is required. -struct AuthenticationSSPI: Message { - /// Identifies the message as an authentication request. 
- mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len = 8, - /// Specifies that SSPI authentication is required. - status: i32 = 9, -} - -/// The `AuthenticationSASL` struct represents a message indicating that SASL authentication is required. -struct AuthenticationSASL: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len, - /// Specifies that SASL authentication is required. - status: i32 = 10, - /// List of SASL authentication mechanisms, terminated by a zero byte. - mechanisms: ZTArray, -} - -/// The `AuthenticationSASLContinue` struct represents a message containing a SASL challenge during the authentication process. -struct AuthenticationSASLContinue: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len, - /// Specifies that this message contains a SASL challenge. - status: i32 = 11, - /// SASL data, specific to the SASL mechanism being used. - data: Rest, -} - -/// The `AuthenticationSASLFinal` struct represents a message indicating the completion of SASL authentication. -struct AuthenticationSASLFinal: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len, - /// Specifies that SASL authentication has completed. - status: i32 = 12, - /// SASL outcome "additional data", specific to the SASL mechanism being used. - data: Rest, -} - -/// The `BackendKeyData` struct represents a message containing the process ID and secret key for this backend. -struct BackendKeyData: Message { - /// Identifies the message as cancellation key data. - mtype: u8 = 'K', - /// Length of message contents in bytes, including self. - mlen: len = 12, - /// The process ID of this backend. - pid: i32, - /// The secret key of this backend. 
- key: i32, -} - -/// The `Bind` struct represents a message to bind a named portal to a prepared statement. -struct Bind: Message { - /// Identifies the message as a Bind command. - mtype: u8 = 'B', - /// Length of message contents in bytes, including self. - mlen: len, - /// The name of the destination portal. - portal: ZTString, - /// The name of the source prepared statement. - statement: ZTString, - /// The parameter format codes. - format_codes: Array, - /// Array of parameter values and their lengths. - values: Array, - /// The result-column format codes. - result_format_codes: Array, -} - -/// The `BindComplete` struct represents a message indicating that a Bind operation was successful. -struct BindComplete: Message { - /// Identifies the message as a Bind-complete indicator. - mtype: u8 = '2', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `CancelRequest` struct represents a message to request the cancellation of a query. -struct CancelRequest: InitialMessage { - /// Length of message contents in bytes, including self. - mlen: len = 16, - /// The cancel request code. - code: i32 = 80877102, - /// The process ID of the target backend. - pid: i32, - /// The secret key for the target backend. - key: i32, -} - -/// The `Close` struct represents a message to close a prepared statement or portal. -struct Close: Message { - /// Identifies the message as a Close command. - mtype: u8 = 'C', - /// Length of message contents in bytes, including self. - mlen: len, - /// 'xS' to close a prepared statement; 'P' to close a portal. - ctype: u8, - /// The name of the prepared statement or portal to close. - name: ZTString, -} - -/// The `CloseComplete` struct represents a message indicating that a Close operation was successful. -struct CloseComplete: Message { - /// Identifies the message as a Close-complete indicator. - mtype: u8 = '3', - /// Length of message contents in bytes, including self. 
- mlen: len = 4, -} - -/// The `CommandComplete` struct represents a message indicating the successful completion of a command. -struct CommandComplete: Message { - /// Identifies the message as a command-completed response. - mtype: u8 = 'C', - /// Length of message contents in bytes, including self. - mlen: len, - /// The command tag. - tag: ZTString, -} - -/// The `CopyData` struct represents a message containing data for a copy operation. -struct CopyData: Message { - /// Identifies the message as COPY data. - mtype: u8 = 'd', - /// Length of message contents in bytes, including self. - mlen: len, - /// Data that forms part of a COPY data stream. - data: Rest, -} - -/// The `CopyDone` struct represents a message indicating that a copy operation is complete. -struct CopyDone: Message { - /// Identifies the message as a COPY-complete indicator. - mtype: u8 = 'c', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `CopyFail` struct represents a message indicating that a copy operation has failed. -struct CopyFail: Message { - /// Identifies the message as a COPY-failure indicator. - mtype: u8 = 'f', - /// Length of message contents in bytes, including self. - mlen: len, - /// An error message to report as the cause of failure. - error_msg: ZTString, -} - -/// The `CopyInResponse` struct represents a message indicating that the server is ready to receive data for a copy-in operation. -struct CopyInResponse: Message { - /// Identifies the message as a Start Copy In response. - mtype: u8 = 'G', - /// Length of message contents in bytes, including self. - mlen: len, - /// 0 for textual, 1 for binary. - format: u8, - /// The format codes for each column. - format_codes: Array, -} - -/// The `CopyOutResponse` struct represents a message indicating that the server is ready to send data for a copy-out operation. -struct CopyOutResponse: Message { - /// Identifies the message as a Start Copy Out response. 
- mtype: u8 = 'H', - /// Length of message contents in bytes, including self. - mlen: len, - /// 0 for textual, 1 for binary. - format: u8, - /// The format codes for each column. - format_codes: Array, -} - -/// The `CopyBothResponse` is used only for Streaming Replication. -struct CopyBothResponse: Message { - /// Identifies the message as a Start Copy Both response. - mtype: u8 = 'W', - /// Length of message contents in bytes, including self. - mlen: len, - /// 0 for textual, 1 for binary. - format: u8, - /// The format codes for each column. - format_codes: Array, -} - -/// The `DataRow` struct represents a message containing a row of data. -struct DataRow: Message { - /// Identifies the message as a data row. - mtype: u8 = 'D', - /// Length of message contents in bytes, including self. - mlen: len, - /// Array of column values and their lengths. - values: Array, -} - -/// The `Describe` struct represents a message to describe a prepared statement or portal. -struct Describe: Message { - /// Identifies the message as a Describe command. - mtype: u8 = 'D', - /// Length of message contents in bytes, including self. - mlen: len, - /// 'S' to describe a prepared statement; 'P' to describe a portal. - dtype: u8, - /// The name of the prepared statement or portal. - name: ZTString, -} - -/// The `EmptyQueryResponse` struct represents a message indicating that an empty query string was recognized. -struct EmptyQueryResponse: Message { - /// Identifies the message as a response to an empty query String. - mtype: u8 = 'I', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `ErrorResponse` struct represents a message indicating that an error has occurred. -struct ErrorResponse: Message { - /// Identifies the message as an error. - mtype: u8 = 'E', - /// Length of message contents in bytes, including self. - mlen: len, - /// Array of error fields and their values. 
- fields: ZTArray, -} - -/// The `ErrorField` struct represents a single error message within an `ErrorResponse`. -struct ErrorField { - /// A code identifying the field type. - etype: u8, - /// The field value. - value: ZTString, -} - -/// The `Execute` struct represents a message to execute a prepared statement or portal. -struct Execute: Message { - /// Identifies the message as an Execute command. - mtype: u8 = 'E', - /// Length of message contents in bytes, including self. - mlen: len, - /// The name of the portal to execute. - portal: ZTString, - /// Maximum number of rows to return. - max_rows: i32, -} - -/// The `Flush` struct represents a message to flush the backend's output buffer. -struct Flush: Message { - /// Identifies the message as a Flush command. - mtype: u8 = 'H', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `FunctionCall` struct represents a message to call a function. -struct FunctionCall: Message { - /// Identifies the message as a function call. - mtype: u8 = 'F', - /// Length of message contents in bytes, including self. - mlen: len, - /// OID of the function to execute. - function_id: i32, - /// The parameter format codes. - format_codes: Array, - /// Array of args and their lengths. - args: Array, - /// The format code for the result. - result_format_code: i16, -} - -/// The `FunctionCallResponse` struct represents a message containing the result of a function call. -struct FunctionCallResponse: Message { - /// Identifies the message as a function-call response. - mtype: u8 = 'V', - /// Length of message contents in bytes, including self. - mlen: len, - /// The function result value. - result: Encoded, -} - -/// The `GSSENCRequest` struct represents a message requesting GSSAPI encryption. -struct GSSENCRequest: InitialMessage { - /// Length of message contents in bytes, including self. - mlen: len = 8, - /// The GSSAPI Encryption request code. 
- gssenc_request_code: i32 = 80877104, -} - -/// The `GSSResponse` struct represents a message containing a GSSAPI or SSPI response. -struct GSSResponse: Message { - /// Identifies the message as a GSSAPI or SSPI response. - mtype: u8 = 'p', - /// Length of message contents in bytes, including self. - mlen: len, - /// GSSAPI or SSPI authentication data. - data: Rest, -} - -/// The `NegotiateProtocolVersion` struct represents a message requesting protocol version negotiation. -struct NegotiateProtocolVersion: Message { - /// Identifies the message as a protocol version negotiation request. - mtype: u8 = 'v', - /// Length of message contents in bytes, including self. - mlen: len, - /// Newest minor protocol version supported by the server. - minor_version: i32, - /// List of protocol options not recognized. - options: Array, -} - -/// The `NoData` struct represents a message indicating that there is no data to return. -struct NoData: Message { - /// Identifies the message as a No Data indicator. - mtype: u8 = 'n', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `NoticeResponse` struct represents a message containing a notice. -struct NoticeResponse: Message { - /// Identifies the message as a notice. - mtype: u8 = 'N', - /// Length of message contents in bytes, including self. - mlen: len, - /// Array of notice fields and their values. - fields: ZTArray, -} - -/// The `NoticeField` struct represents a single error message within an `NoticeResponse`. -struct NoticeField: Message { - /// A code identifying the field type. - ntype: u8, - /// The field value. - value: ZTString, -} - -/// The `NotificationResponse` struct represents a message containing a notification from the backend. -struct NotificationResponse: Message { - /// Identifies the message as a notification. - mtype: u8 = 'A', - /// Length of message contents in bytes, including self. - mlen: len, - /// The process ID of the notifying backend. 
- pid: i32, - /// The name of the notification channel. - channel: ZTString, - /// The notification payload. - payload: ZTString, -} - -/// The `ParameterDescription` struct represents a message describing the parameters needed by a prepared statement. -struct ParameterDescription: Message { - /// Identifies the message as a parameter description. - mtype: u8 = 't', - /// Length of message contents in bytes, including self. - mlen: len, - /// OIDs of the parameter data types. - param_types: Array, -} - -/// The `ParameterStatus` struct represents a message containing the current status of a parameter. -struct ParameterStatus: Message { - /// Identifies the message as a runtime parameter status report. - mtype: u8 = 'S', - /// Length of message contents in bytes, including self. - mlen: len, - /// The name of the parameter. - name: ZTString, - /// The current value of the parameter. - value: ZTString, -} - -/// The `Parse` struct represents a message to parse a query string. -struct Parse: Message { - /// Identifies the message as a Parse command. - mtype: u8 = 'P', - /// Length of message contents in bytes, including self. - mlen: len, - /// The name of the destination prepared statement. - statement: ZTString, - /// The query String to be parsed. - query: ZTString, - /// OIDs of the parameter data types. - param_types: Array, -} - -/// The `ParseComplete` struct represents a message indicating that a Parse operation was successful. -struct ParseComplete: Message { - /// Identifies the message as a Parse-complete indicator. - mtype: u8 = '1', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `PasswordMessage` struct represents a message containing a password. -struct PasswordMessage: Message { - /// Identifies the message as a password response. - mtype: u8 = 'p', - /// Length of message contents in bytes, including self. - mlen: len, - /// The password (encrypted or plaintext, depending on context). 
- password: ZTString, -} - -/// The `PortalSuspended` struct represents a message indicating that a portal has been suspended. -struct PortalSuspended: Message { - /// Identifies the message as a portal-suspended indicator. - mtype: u8 = 's', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `Query` struct represents a message to execute a simple query. -struct Query: Message { - /// Identifies the message as a simple query command. - mtype: u8 = 'Q', - /// Length of message contents in bytes, including self. - mlen: len, - /// The query String to be executed. - query: ZTString, -} - -/// The `ReadyForQuery` struct represents a message indicating that the backend is ready for a new query. -struct ReadyForQuery: Message { - /// Identifies the message as a ready-for-query indicator. - mtype: u8 = 'Z', - /// Length of message contents in bytes, including self. - mlen: len = 5, - /// Current transaction status indicator. - status: u8, -} - -/// The `RowDescription` struct represents a message describing the rows that will be returned by a query. -struct RowDescription: Message { - /// Identifies the message as a row description. - mtype: u8 = 'T', - /// Length of message contents in bytes, including self. - mlen: len, - /// Array of field descriptions. - fields: Array, -} - -/// The `RowField` struct represents a row within the `RowDescription` message. 
-struct RowField { - /// The field name - name: ZTString, - /// The table ID (OID) of the table the column is from, or 0 if not a column reference - table_oid: i32, - /// The attribute number of the column, or 0 if not a column reference - column_attr_number: i16, - /// The object ID of the field's data type - data_type_oid: i32, - /// The data type size (negative if variable size) - data_type_size: i16, - /// The type modifier - type_modifier: i32, - /// The format code being used for the field (0 for text, 1 for binary) - format_code: i16, -} - -/// The `SASLInitialResponse` struct represents a message containing a SASL initial response. -struct SASLInitialResponse: Message { - /// Identifies the message as a SASL initial response. - mtype: u8 = 'p', - /// Length of message contents in bytes, including self. - mlen: len, - /// Name of the SASL authentication mechanism. - mechanism: ZTString, - /// SASL initial response data. - response: Array, -} - -/// The `SASLResponse` struct represents a message containing a SASL response. -struct SASLResponse: Message { - /// Identifies the message as a SASL response. - mtype: u8 = 'p', - /// Length of message contents in bytes, including self. - mlen: len, - /// SASL response data. - response: Rest, -} - -/// The `SSLRequest` struct represents a message requesting SSL encryption. -struct SSLRequest: InitialMessage { - /// Length of message contents in bytes, including self. - mlen: len = 8, - /// The SSL request code. - code: i32 = 80877103, -} - -struct SSLResponse { - /// Specifies if SSL was accepted or rejected. - code: u8, -} - -/// The `StartupMessage` struct represents a message to initiate a connection. -struct StartupMessage: InitialMessage { - /// Length of message contents in bytes, including self. - mlen: len, - /// The protocol version number. - protocol: i32 = 196608, - /// List of parameter name-value pairs, terminated by a zero byte. 
- params: ZTArray, -} - -/// The `StartupMessage` struct represents a name/value pair within the `StartupMessage` message. -struct StartupNameValue { - /// The parameter name. - name: ZTString, - /// The parameter value. - value: ZTString, -} - -/// The `Sync` struct represents a message to synchronize the frontend and backend. -struct Sync: Message { - /// Identifies the message as a Sync command. - mtype: u8 = 'S', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `Terminate` struct represents a message to terminate a connection. -struct Terminate: Message { - /// Identifies the message as a Terminate command. - mtype: u8 = 'X', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} -); - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_all() { - let message = meta::Message::default(); - let initial_message = meta::InitialMessage::default(); - - for meta in meta::ALL { - eprintln!("{meta:#?}"); - if **meta != message && **meta != initial_message { - if meta.field("mtype").is_some() && meta.field("mlen").is_some() { - // If a message has mtype and mlen, it should subclass Message - assert_eq!(*meta.parent().unwrap(), message); - } else if meta.field("mlen").is_some() { - // If a message has mlen only, it should subclass InitialMessage - assert_eq!(*meta.parent().unwrap(), initial_message); - } - } - } - } -} diff --git a/rust/pgrust/src/protocol/edgedb.rs b/rust/pgrust/src/protocol/edgedb.rs index 98a25c2155b..da32a8d6367 100644 --- a/rust/pgrust/src/protocol/edgedb.rs +++ b/rust/pgrust/src/protocol/edgedb.rs @@ -1,5 +1,5 @@ -use super::gen::protocol; -use crate::protocol::message_group::message_group; +use db_proto::{message_group, protocol}; + message_group!( EdgeDBBackend: Message = [ AuthenticationOk, diff --git a/rust/pgrust/src/protocol/mod.rs b/rust/pgrust/src/protocol/mod.rs index aead9728e4e..c09f9c97449 100644 --- a/rust/pgrust/src/protocol/mod.rs +++ 
b/rust/pgrust/src/protocol/mod.rs @@ -1,168 +1,10 @@ -mod arrays; -mod buffer; -mod datatypes; pub mod edgedb; -mod gen; -mod message_group; pub mod postgres; -mod writer; - -/// Metatypes for the protocol and related arrays/strings. -pub mod meta { - pub use super::arrays::meta::*; - pub use super::datatypes::meta::*; -} - -#[allow(unused)] -pub use arrays::{Array, ArrayIter, ZTArray, ZTArrayIter}; -pub use buffer::StructBuffer; -#[allow(unused)] -pub use datatypes::{Encoded, LString, Rest, ZTString}; -pub use message_group::match_message; -pub use writer::BufWriter; - -#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq, Eq)] -pub enum ParseError { - #[error("Buffer is too short")] - TooShort, - #[error("Invalid data")] - InvalidData, -} - -/// Implemented for all structs. -pub trait StructMeta { - type Struct<'a>: std::fmt::Debug; - fn new(buf: &[u8]) -> Result, ParseError>; - fn to_vec(s: &Self::Struct<'_>) -> Vec; -} - -/// Implemented for all generated structs that have a [`meta::Length`] field at a fixed offset. -pub trait StructLength: StructMeta { - fn length_field_of(of: &Self::Struct<'_>) -> usize; - fn length_field_offset() -> usize; - fn length_of_buf(buf: &[u8]) -> Option { - if buf.len() < Self::length_field_offset() + std::mem::size_of::() { - None - } else { - let len = FieldAccess::::extract( - &buf[Self::length_field_offset() - ..Self::length_field_offset() + std::mem::size_of::()], - ) - .ok()?; - Some(Self::length_field_offset() + len) - } - } -} - -/// For a given metaclass, returns the inflated type, a measurement type and a -/// builder type. -pub trait Enliven { - type WithLifetime<'a>; - type ForMeasure<'a>: 'a; - type ForBuilder<'a>: 'a; -} - -pub trait FixedSize: Enliven { - const SIZE: usize; - /// Extract this type from the given buffer, assuming that enough bytes are available. 
- fn extract_infallible(buf: &[u8]) -> ::WithLifetime<'_>; -} - -#[derive(Debug, Eq, PartialEq)] -pub enum MetaRelation { - Parent, - Length, - Item, - Field(&'static str), -} - -pub trait Meta { - fn name(&self) -> &'static str { - std::any::type_name::() - } - fn relations(&self) -> &'static [(MetaRelation, &'static dyn Meta)] { - &[] - } - fn field(&self, name: &'static str) -> Option<&'static dyn Meta> { - for (relation, meta) in self.relations() { - if relation == &MetaRelation::Field(name) { - return Some(*meta); - } - } - None - } - fn parent(&self) -> Option<&'static dyn Meta> { - for (relation, meta) in self.relations() { - if relation == &MetaRelation::Parent { - return Some(*meta); - } - } - None - } -} - -impl PartialEq for dyn Meta { - fn eq(&self, other: &T) -> bool { - other.name() == self.name() - } -} - -impl std::fmt::Debug for dyn Meta { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut s = f.debug_struct(self.name()); - for (relation, meta) in self.relations() { - if relation == &MetaRelation::Parent { - s.field(&format!("{relation:?}"), &meta.name()); - } else { - s.field(&format!("{relation:?}"), meta); - } - } - s.finish() - } -} - -/// Delegates to a concrete [`FieldAccess`] but as a non-const trait. This is -/// used for performing extraction in iterators. -pub(crate) trait FieldAccessArray: Enliven { - const META: &'static dyn Meta; - fn size_of_field_at(buf: &[u8]) -> Result; - fn extract(buf: &[u8]) -> Result<::WithLifetime<'_>, ParseError>; -} - -/// This struct is specialized for each type we want to extract data from. We -/// have to do it this way to work around Rust's lack of const specialization. -pub(crate) struct FieldAccess { - _phantom_data: std::marker::PhantomData, -} - -/// Delegate to the concrete [`FieldAccess`] for each type we want to extract. -macro_rules! 
field_access { - ($ty:ty) => { - impl $crate::protocol::FieldAccessArray for $ty { - const META: &'static dyn $crate::protocol::Meta = - $crate::protocol::FieldAccess::<$ty>::meta(); - #[inline(always)] - fn size_of_field_at(buf: &[u8]) -> Result { - $crate::protocol::FieldAccess::<$ty>::size_of_field_at(buf) - } - #[inline(always)] - fn extract( - buf: &[u8], - ) -> Result< - ::WithLifetime<'_>, - $crate::protocol::ParseError, - > { - $crate::protocol::FieldAccess::<$ty>::extract(buf) - } - } - }; -} -pub(crate) use field_access; #[cfg(test)] mod tests { use super::*; - use buffer::StructBuffer; + use db_proto::{match_message, Encoded, StructBuffer, StructMeta}; use postgres::{builder, data::*, measure, meta}; use rand::Rng; /// We want to ensure that no malformed messages will cause unexpected @@ -610,6 +452,17 @@ mod tests { fuzz_test::(message); } + #[test] + fn test_datarow() { + let buf = [ + 0x44, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, + ]; + assert!(DataRow::is_buffer(&buf)); + let message = DataRow::new(&buf).unwrap(); + assert_eq!(message.values().len(), 1); + assert_eq!(message.values().into_iter().next().unwrap(), Encoded::Null); + } + #[test] fn test_edgedb_sasl() { use crate::protocol::edgedb::*; diff --git a/rust/pgrust/src/protocol/postgres.rs b/rust/pgrust/src/protocol/postgres.rs index 04bbdc106f5..bccc5eb126d 100644 --- a/rust/pgrust/src/protocol/postgres.rs +++ b/rust/pgrust/src/protocol/postgres.rs @@ -1,5 +1,4 @@ -use super::gen::protocol; -use super::message_group::message_group; +use db_proto::{message_group, protocol}; message_group!( /// The `Backend` message group contains messages sent from the backend to the frontend. @@ -282,7 +281,7 @@ struct Close: Message { mtype: u8 = 'C', /// Length of message contents in bytes, including self. mlen: len, - /// 'xS' to close a prepared statement; 'P' to close a portal. + /// 'S' to close a prepared statement; 'P' to close a portal. 
ctype: u8, /// The name of the prepared statement or portal to close. name: ZTString, @@ -564,7 +563,7 @@ struct Parse: Message { mlen: len, /// The name of the destination prepared statement. statement: ZTString, - /// The query String to be parsed. + /// The query string to be parsed. query: ZTString, /// OIDs of the parameter data types. param_types: Array, diff --git a/rust/pgrust/src/python.rs b/rust/pgrust/src/python.rs index 0914e181367..6cb61da2bc1 100644 --- a/rust/pgrust/src/python.rs +++ b/rust/pgrust/src/python.rs @@ -11,11 +11,9 @@ use crate::{ }, ConnectionSslRequirement, }, - protocol::{ - postgres::{data::SSLResponse, meta, FrontendBuilder, InitialBuilder}, - StructBuffer, - }, + protocol::postgres::{data::SSLResponse, meta, FrontendBuilder, InitialBuilder}, }; +use db_proto::StructBuffer; use pyo3::{ buffer::PyBuffer, exceptions::{PyException, PyRuntimeError}, @@ -50,12 +48,6 @@ impl From for PyErr { } } -impl From for PyErr { - fn from(err: crate::protocol::ParseError) -> PyErr { - PyRuntimeError::new_err(err.to_string()) - } -} - impl EnvVar for (String, Bound<'_, PyAny>) { fn read(&self, name: &'static str) -> Option> { // os.environ[name], or the default user if not @@ -359,7 +351,8 @@ impl PyConnectionState { if self.inner.read_ssl_response() { // SSL responses are always one character let response = [buffer.as_slice(py).unwrap().first().unwrap().get()]; - let response = SSLResponse::new(&response)?; + let response = + SSLResponse::new(&response).map_err(|e| PyException::new_err(e.to_string()))?; self.inner .drive(ConnectionDrive::SslResponse(response), &mut self.update)?; } else { diff --git a/rust/pgrust/tests/query_real_postgres.rs b/rust/pgrust/tests/query_real_postgres.rs new file mode 100644 index 00000000000..ce6a9bfa41f --- /dev/null +++ b/rust/pgrust/tests/query_real_postgres.rs @@ -0,0 +1,354 @@ +use std::cell::RefCell; +use std::future::Future; +use std::num::NonZero; +use std::rc::Rc; + +// Constants +use db_proto::match_message; 
+use gel_auth::AuthType; +use pgrust::connection::tokio::TokioStream; +use pgrust::connection::{ + Client, Credentials, FlowAccumulator, MaxRows, Oid, Param, PipelineBuilder, Portal, + ResolvedTarget, Statement, +}; +use pgrust::protocol::postgres::data::*; +use tokio::task::LocalSet; + +use captive_postgres::*; + +fn address(address: &ListenAddress) -> ResolvedTarget { + match address { + ListenAddress::Tcp(addr) => ResolvedTarget::SocketAddr(*addr), + #[cfg(unix)] + ListenAddress::Unix(path) => ResolvedTarget::UnixSocketAddr( + std::os::unix::net::SocketAddr::from_pathname(path).unwrap(), + ), + } +} + +async fn with_postgres(callback: F) -> Result, Box> +where + F: FnOnce(Client, Rc>) -> R, + R: Future>>, +{ + let Some(postgres_process) = setup_postgres(AuthType::Trust, Mode::Tcp)? else { + return Ok(None); + }; + + let credentials = Credentials { + username: DEFAULT_USERNAME.to_string(), + password: DEFAULT_PASSWORD.to_string(), + database: DEFAULT_DATABASE.to_string(), + server_settings: Default::default(), + }; + + let socket = address(&postgres_process.socket_address).connect().await?; + let (client, task) = Client::new(credentials, socket, ()); + let accumulator = Rc::new(RefCell::new(FlowAccumulator::default())); + + let accumulator2 = accumulator.clone(); + LocalSet::new() + .run_until(async move { + tokio::task::spawn_local(task); + client.ready().await?; + callback(client, accumulator2.clone()).await?; + Result::<(), Box>::Ok(()) + }) + .await?; + + let mut s = String::new(); + accumulator.borrow().with_messages(|message| { + match_message!(Ok(message), Backend { + (ParameterDescription as params) => { + // OID values are not guaranteed to be stable, so we just print "..." instead. 
+ s.push_str(&format!("ParameterDescription {:?}\n", params.param_types().into_iter().map(|_| "...").collect::>())); + }, + (RowDescription as rows) => { + s.push_str(&format!("RowDescription {}\n", rows.fields().into_iter().map(|f| f.name().to_string_lossy().into_owned()).collect::>().join(", "))); + }, + (PortalSuspended) => { + s.push_str("PortalSuspended\n"); + }, + (ErrorResponse as err) => { + for field in err.fields() { + if field.etype() as char == 'C' { + s.push_str(&format!("ErrorResponse {}\n", field.value().to_string_lossy())); + return; + } + } + s.push_str(&format!("ErrorResponse {:?}\n", err)); + }, + (NoticeResponse as notice) => { + for field in notice.fields() { + if field.ntype() as char == 'M' { + s.push_str(&format!("NoticeResponse {}\n", field.value().to_string_lossy())); + return; + } + } + s.push_str(&format!("NoticeResponse {:?}\n", notice)); + }, + (CommandComplete as cmd) => { + s.push_str(&format!("CommandComplete {:?}\n", cmd.tag())); + }, + (DataRow as row) => { + s.push_str(&format!("DataRow {}\n", row.values().into_iter().map(|v| v.to_string_lossy().into_owned()).collect::>().join(", "))); + }, + (CopyData as copy_data) => { + s.push_str(&format!("CopyData {:?}\n", String::from_utf8_lossy(©_data.data()))); + }, + (CopyOutResponse as copy_out) => { + s.push_str(&format!("CopyOutResponse {}\n", copy_out.format())); + }, + _unknown => { + s.push_str("Unknown\n"); + } + }) + }); + + Ok(Some(s)) +} + +#[test_log::test(tokio::test)] +async fn test_query() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + client.query("SELECT 1", accumulator.clone()).await?; + Ok(()) + }) + .await? 
+ { + assert_eq!( + s, + "RowDescription ?column?\nDataRow 1\nCommandComplete \"SELECT 1\"\n" + ); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_extended_query_success() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + client + .pipeline_sync( + PipelineBuilder::default() + .parse( + Statement("test"), + "SELECT $1", + &[Oid::unspecified()], + accumulator.clone(), + ) + .describe_statement(Statement("test"), accumulator.clone()) + .bind( + Portal("test"), + Statement("test"), + &[Param::Text("1")], + &[], + accumulator.clone(), + ) + .describe_portal(Portal("test"), accumulator.clone()) + .execute( + Portal("test"), + MaxRows::Limited(NonZero::new(1).unwrap()), + accumulator.clone(), + ) + .build(), + ) + .await?; + Ok(()) + }) + .await? + { + assert_eq!(s, "ParameterDescription [\"...\"]\nRowDescription ?column?\nRowDescription ?column?\nDataRow 1\nPortalSuspended\n"); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_extended_query_parse_error() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + client + .pipeline_sync( + PipelineBuilder::default() + .parse(Statement("test"), ".", &[], accumulator.clone()) + .bind( + Portal("test"), + Statement("test"), + &[], + &[], + accumulator.clone(), + ) + .query("SELECT 1", accumulator.clone()) + .build(), + ) + .await?; + Ok(()) + }) + .await? 
+ { + assert_eq!(s, "ErrorResponse 42601\n"); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_extended_query_portal_suspended() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + client + .pipeline_sync( + PipelineBuilder::default() + .parse( + Statement("test"), + "SELECT generate_series(1,3)", + &[], + accumulator.clone(), + ) + .bind( + Portal("test"), + Statement("test"), + &[], + &[], + accumulator.clone(), + ) + .execute( + Portal("test"), + MaxRows::Limited(NonZero::new(2).unwrap()), + accumulator.clone(), + ) + .execute( + Portal("test"), + MaxRows::Limited(NonZero::new(2).unwrap()), + accumulator.clone(), + ) + .build(), + ) + .await?; + Ok(()) + }) + .await? + { + assert_eq!( + s, + "DataRow 1\nDataRow 2\nPortalSuspended\nDataRow 3\nCommandComplete \"SELECT 1\"\n" + ); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_extended_query_copy() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + client + .pipeline_sync( + PipelineBuilder::default() + .parse( + Statement("test"), + "COPY (SELECT 1) TO STDOUT", + &[], + accumulator.clone(), + ) + .bind( + Portal("test"), + Statement("test"), + &[], + &[], + accumulator.clone(), + ) + .execute(Portal("test"), MaxRows::Unlimited, accumulator.clone()) + .build(), + ) + .await?; + Ok(()) + }) + .await? + { + assert_eq!( + s, + "CopyOutResponse 0\nCopyData \"1\\n\"\nCommandComplete \"COPY 1\"\n" + ); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_extended_query_empty() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + client + .pipeline_sync( + PipelineBuilder::default() + .parse(Statement("test"), "", &[], accumulator.clone()) + .bind( + Portal("test"), + Statement("test"), + &[], + &[], + accumulator.clone(), + ) + .execute(Portal("test"), MaxRows::Unlimited, accumulator.clone()) + .build(), + ) + .await?; + Ok(()) + }) + .await? 
+ { + assert_eq!(s, ""); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_query_notice() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + // DO block with NOTICE RAISE generates a notice + client + .query( + "DO $$ BEGIN RAISE NOTICE 'test notice'; END $$;", + accumulator.clone(), + ) + .await?; + Ok(()) + }) + .await? + { + assert_eq!(s, "NoticeResponse test notice\nCommandComplete \"DO\"\n"); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_query_warning() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + // DO block with WARNING RAISE generates a warning + client + .query( + "DO $$ BEGIN RAISE WARNING 'test warning'; END $$;", + accumulator.clone(), + ) + .await?; + Ok(()) + }) + .await? + { + assert_eq!(s, "NoticeResponse test warning\nCommandComplete \"DO\"\n"); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_double_begin_transaction() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + client + .pipeline_sync( + PipelineBuilder::default() + .query("BEGIN TRANSACTION", accumulator.clone()) + .query("BEGIN TRANSACTION", accumulator.clone()) + .build(), + ) + .await?; + Ok(()) + }) + .await? 
+ { + assert_eq!(s, "CommandComplete \"BEGIN\"\nNoticeResponse there is already a transaction in progress\nCommandComplete \"BEGIN\"\n"); + } + Ok(()) +} diff --git a/rust/pgrust/tests/real_postgres.rs b/rust/pgrust/tests/real_postgres.rs index d9238e6f5d9..0d562e2047e 100644 --- a/rust/pgrust/tests/real_postgres.rs +++ b/rust/pgrust/tests/real_postgres.rs @@ -1,378 +1,22 @@ // Constants use gel_auth::AuthType; -use openssl::ssl::{Ssl, SslContext, SslMethod}; -use pgrust::connection::dsn::{Host, HostType}; use pgrust::connection::{connect_raw_ssl, ConnectionError, Credentials, ResolvedTarget}; use pgrust::errors::PgServerError; use pgrust::handshake::ConnectionSslRequirement; use rstest::rstest; -use std::io::{BufRead, BufReader, Write}; -use std::net::{Ipv4Addr, SocketAddr, TcpListener}; -use std::os::unix::fs::PermissionsExt; -use std::path::{Path, PathBuf}; -use std::process::{Command, Stdio}; -use std::sync::{Arc, RwLock}; -use std::thread; -use std::time::{Duration, Instant}; -use tempfile::TempDir; -const STARTUP_TIMEOUT_DURATION: Duration = Duration::from_secs(30); -const PORT_RELEASE_TIMEOUT: Duration = Duration::from_secs(30); -const LINGER_DURATION: Duration = Duration::from_secs(1); -const HOT_LOOP_INTERVAL: Duration = Duration::from_millis(100); -const DEFAULT_USERNAME: &str = "username"; -const DEFAULT_PASSWORD: &str = "password"; -const DEFAULT_DATABASE: &str = "postgres"; +use captive_postgres::*; -/// Represents an ephemeral port that can be allocated and released for immediate re-use by another process. -struct EphemeralPort { - port: u16, - listener: Option, -} - -impl EphemeralPort { - /// Allocates a new ephemeral port. - /// - /// Returns a Result containing the EphemeralPort if successful, - /// or an IO error if the allocation fails. 
- fn allocate() -> std::io::Result { - let socket = socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; - socket.set_reuse_address(true)?; - socket.set_reuse_port(true)?; - socket.set_linger(Some(LINGER_DURATION))?; - socket.bind(&std::net::SocketAddr::from((Ipv4Addr::LOCALHOST, 0)).into())?; - socket.listen(1)?; - let listener = TcpListener::from(socket); - let port = listener.local_addr()?.port(); - Ok(EphemeralPort { - port, - listener: Some(listener), - }) - } - - /// Consumes the EphemeralPort and returns the allocated port number. - fn take(self) -> u16 { - // Drop the listener to free up the port - drop(self.listener); - - // Loop until the port is free - let start = Instant::now(); - - // If we can successfully connect to the port, it's not fully closed - while start.elapsed() < PORT_RELEASE_TIMEOUT { - let res = std::net::TcpStream::connect((Ipv4Addr::LOCALHOST, self.port)); - if res.is_err() { - // If connection fails, the port is released - break; - } - std::thread::sleep(HOT_LOOP_INTERVAL); - } - - self.port +fn address(address: &ListenAddress) -> ResolvedTarget { + match address { + ListenAddress::Tcp(addr) => ResolvedTarget::SocketAddr(*addr), + #[cfg(unix)] + ListenAddress::Unix(path) => ResolvedTarget::UnixSocketAddr( + std::os::unix::net::SocketAddr::from_pathname(path).unwrap(), + ), } } -struct StdioReader { - output: Arc>, -} - -impl StdioReader { - fn spawn(reader: R, prefix: &'static str) -> Self { - let output = Arc::new(RwLock::new(String::new())); - let output_clone = Arc::clone(&output); - - thread::spawn(move || { - let mut buf_reader = std::io::BufReader::new(reader); - loop { - let mut line = String::new(); - match buf_reader.read_line(&mut line) { - Ok(0) => break, - Ok(_) => { - if let Ok(mut output) = output_clone.write() { - output.push_str(&line); - } - eprint!("[{}]: {}", prefix, line); - } - Err(e) => { - let error_line = format!("Error reading {}: {}\n", prefix, e); - if let Ok(mut output) = 
output_clone.write() { - output.push_str(&error_line); - } - eprintln!("{}", error_line); - } - } - } - }); - - StdioReader { output } - } - - fn contains(&self, s: &str) -> bool { - if let Ok(output) = self.output.read() { - output.contains(s) - } else { - false - } - } -} - -fn init_postgres(initdb: &Path, data_dir: &Path, auth: AuthType) -> std::io::Result<()> { - let mut pwfile = tempfile::NamedTempFile::new()?; - writeln!(pwfile, "{}", DEFAULT_PASSWORD)?; - let mut command = Command::new(initdb); - command - .arg("-D") - .arg(data_dir) - .arg("-A") - .arg(match auth { - AuthType::Deny => "reject", - AuthType::Trust => "trust", - AuthType::Plain => "password", - AuthType::Md5 => "md5", - AuthType::ScramSha256 => "scram-sha-256", - }) - .arg("--pwfile") - .arg(pwfile.path()) - .arg("-U") - .arg(DEFAULT_USERNAME); - - let output = command.output()?; - - let status = output.status; - let output_str = String::from_utf8_lossy(&output.stdout).to_string(); - let error_str = String::from_utf8_lossy(&output.stderr).to_string(); - - eprintln!("initdb stdout:\n{}", output_str); - eprintln!("initdb stderr:\n{}", error_str); - - if !status.success() { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "initdb command failed", - )); - } - - Ok(()) -} - -fn run_postgres( - postgres_bin: &Path, - data_dir: &Path, - socket_path: &Path, - ssl: Option<(PathBuf, PathBuf)>, - port: u16, -) -> std::io::Result { - let mut command = Command::new(postgres_bin); - command - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .arg("-D") - .arg(data_dir) - .arg("-k") - .arg(socket_path) - .arg("-h") - .arg(Ipv4Addr::LOCALHOST.to_string()) - .arg("-F") - // Useful for debugging - // .arg("-d") - // .arg("5") - .arg("-p") - .arg(port.to_string()); - - if let Some((cert_path, key_path)) = ssl { - let postgres_cert_path = data_dir.join("server.crt"); - let postgres_key_path = data_dir.join("server.key"); - std::fs::copy(cert_path, &postgres_cert_path)?; - std::fs::copy(key_path, 
&postgres_key_path)?; - // Set permissions for the certificate and key files - std::fs::set_permissions(&postgres_cert_path, std::fs::Permissions::from_mode(0o600))?; - std::fs::set_permissions(&postgres_key_path, std::fs::Permissions::from_mode(0o600))?; - - // Edit pg_hba.conf to change all "host" line prefixes to "hostssl" - let pg_hba_path = data_dir.join("pg_hba.conf"); - let content = std::fs::read_to_string(&pg_hba_path)?; - let modified_content = content - .lines() - .filter(|line| !line.starts_with("#") && !line.is_empty()) - .map(|line| { - if line.trim_start().starts_with("host") { - line.replacen("host", "hostssl", 1) - } else { - line.to_string() - } - }) - .collect::>() - .join("\n"); - eprintln!("pg_hba.conf:\n==========\n{modified_content}\n=========="); - std::fs::write(&pg_hba_path, modified_content)?; - - command.arg("-l"); - } - - let mut child = command.spawn()?; - - let stdout_reader = BufReader::new(child.stdout.take().expect("Failed to capture stdout")); - let _ = StdioReader::spawn(stdout_reader, "stdout"); - let stderr_reader = BufReader::new(child.stderr.take().expect("Failed to capture stderr")); - let stderr_reader = StdioReader::spawn(stderr_reader, "stderr"); - - let start_time = Instant::now(); - - let mut tcp_socket: Option = None; - let mut unix_socket: Option = None; - - let unix_socket_path = get_unix_socket_path(socket_path, port); - let tcp_socket_addr = std::net::SocketAddr::from((Ipv4Addr::LOCALHOST, port)); - let mut db_ready = false; - - while start_time.elapsed() < STARTUP_TIMEOUT_DURATION { - std::thread::sleep(HOT_LOOP_INTERVAL); - match child.try_wait() { - Ok(Some(status)) => { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - format!("PostgreSQL exited with status: {}", status), - )) - } - Err(e) => return Err(e), - _ => {} - } - if !db_ready && stderr_reader.contains("database system is ready to accept connections") { - eprintln!("Database is ready"); - db_ready = true; - } else { - continue; - } - if 
unix_socket.is_none() { - unix_socket = std::os::unix::net::UnixStream::connect(&unix_socket_path).ok(); - } - if tcp_socket.is_none() { - tcp_socket = std::net::TcpStream::connect(tcp_socket_addr).ok(); - } - if unix_socket.is_some() && tcp_socket.is_some() { - break; - } - } - - if unix_socket.is_some() && tcp_socket.is_some() { - return Ok(child); - } - - // Print status for TCP/unix sockets - if let Some(tcp) = &tcp_socket { - eprintln!( - "TCP socket at {tcp_socket_addr:?} bound successfully on {}", - tcp.local_addr()? - ); - } else { - eprintln!("TCP socket at {tcp_socket_addr:?} binding failed"); - } - - if unix_socket.is_some() { - eprintln!("Unix socket at {unix_socket_path:?} connected successfully"); - } else { - eprintln!("Unix socket at {unix_socket_path:?} connection failed"); - } - - Err(std::io::Error::new( - std::io::ErrorKind::TimedOut, - "PostgreSQL failed to start within 30 seconds", - )) -} - -fn test_data_dir() -> std::path::PathBuf { - Path::new("../../../tests") - .canonicalize() - .expect("Failed to canonicalize tests directory path") -} - -fn postgres_bin_dir() -> std::io::Result { - Path::new("../../../build/postgres/install/bin").canonicalize() -} - -fn get_unix_socket_path(socket_path: &Path, port: u16) -> PathBuf { - socket_path.join(format!(".s.PGSQL.{}", port)) -} - -#[derive(Debug, Clone, Copy)] -enum Mode { - Tcp, - TcpSsl, - Unix, -} - -fn create_ssl_client() -> Result> { - let ssl_context = SslContext::builder(SslMethod::tls_client())?.build(); - let mut ssl = Ssl::new(&ssl_context)?; - ssl.set_connect_state(); - Ok(ssl) -} -struct PostgresProcess { - child: std::process::Child, - socket_address: ResolvedTarget, - #[allow(unused)] - temp_dir: TempDir, -} - -impl Drop for PostgresProcess { - fn drop(&mut self) { - let _ = self.child.kill(); - } -} - -fn setup_postgres( - auth: AuthType, - mode: Mode, -) -> Result, Box> { - let Ok(bindir) = postgres_bin_dir() else { - println!("Skipping test: postgres bin dir not found"); - return 
Ok(None); - }; - - let initdb = bindir.join("initdb"); - let postgres = bindir.join("postgres"); - - if !initdb.exists() || !postgres.exists() { - println!("Skipping test: initdb or postgres not found"); - return Ok(None); - } - - let temp_dir = TempDir::new()?; - let port = EphemeralPort::allocate()?; - let data_dir = temp_dir.path().join("data"); - - init_postgres(&initdb, &data_dir, auth)?; - let ssl_key = match mode { - Mode::TcpSsl => { - let certs_dir = test_data_dir().join("certs"); - let cert = certs_dir.join("server.cert.pem"); - let key = certs_dir.join("server.key.pem"); - Some((cert, key)) - } - _ => None, - }; - - let port = port.take(); - let child = run_postgres(&postgres, &data_dir, &data_dir, ssl_key, port)?; - - let socket_address = match mode { - Mode::Unix => ResolvedTarget::to_addrs_sync(&Host( - HostType::Path(data_dir.to_string_lossy().to_string()), - port, - ))? - .remove(0), - Mode::Tcp | Mode::TcpSsl => { - ResolvedTarget::SocketAddr(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port)) - } - }; - - Ok(Some(PostgresProcess { - child, - socket_address, - temp_dir, - })) -} - #[rstest] #[tokio::test] async fn test_auth_real( @@ -391,7 +35,7 @@ async fn test_auth_real( server_settings: Default::default(), }; - let client = postgres_process.socket_address.connect().await?; + let client = address(&postgres_process.socket_address).connect().await?; let ssl_requirement = match mode { Mode::TcpSsl => ConnectionSslRequirement::Required, @@ -426,7 +70,7 @@ async fn test_bad_password( server_settings: Default::default(), }; - let client = postgres_process.socket_address.connect().await?; + let client = address(&postgres_process.socket_address).connect().await?; let ssl_requirement = match mode { Mode::TcpSsl => ConnectionSslRequirement::Required, @@ -458,7 +102,7 @@ async fn test_bad_username( server_settings: Default::default(), }; - let client = postgres_process.socket_address.connect().await?; + let client = 
address(&postgres_process.socket_address).connect().await?; let ssl_requirement = match mode { Mode::TcpSsl => ConnectionSslRequirement::Required, @@ -490,7 +134,7 @@ async fn test_bad_database( server_settings: Default::default(), }; - let client = postgres_process.socket_address.connect().await?; + let client = address(&postgres_process.socket_address).connect().await?; let ssl_requirement = match mode { Mode::TcpSsl => ConnectionSslRequirement::Required, diff --git a/tests/inplace-testing/test.sh b/tests/inplace-testing/test.sh index f7265fc1e20..d3ac8012738 100755 --- a/tests/inplace-testing/test.sh +++ b/tests/inplace-testing/test.sh @@ -151,6 +151,9 @@ if $EDGEDB query 'create empty branch asdf'; then fi $EDGEDB query 'configure instance reset force_database_error' stop_server +if [ "$SAVE_TARBALLS" = 1 ]; then + tar cf "$DIR"-cooked2.tar "$DIR" +fi # Test! diff --git a/tests/test_edgeql_sys.py b/tests/test_edgeql_sys.py index 5909661e757..319fe92d7ae 100644 --- a/tests/test_edgeql_sys.py +++ b/tests/test_edgeql_sys.py @@ -38,12 +38,13 @@ async def _configure_track(self, option: str): async def _bad_query_for_stats(self): raise NotImplementedError - async def _test_sys_query_stats(self): + def _before_test_sys_query_stats(self): if self.backend_dsn: self.skipTest( "can't run query stats test when extension isn't present" ) + async def _test_sys_query_stats(self): stats_query = f''' with stats := ( select @@ -177,7 +178,15 @@ async def _bad_query_for_stats(self): await self.con.query(f'select {self.stats_magic_word}_NoSuchType') async def test_edgeql_sys_query_stats(self): - await self._test_sys_query_stats() + self._before_test_sys_query_stats() + async with tb.start_edgedb_server() as sd: + old_con = self.con + self.con = await sd.connect() + try: + await self._test_sys_query_stats() + finally: + await self.con.aclose() + self.con = old_con class TestSQLSys(tb.SQLQueryTestCase, TestQueryStatsMixin): @@ -215,4 +224,14 @@ async def _bad_query_for_stats(self): 
) async def test_sql_sys_query_stats(self): - await self._test_sys_query_stats() + self._before_test_sys_query_stats() + async with tb.start_edgedb_server() as sd: + old_cons = self.con, self.scon + self.con = await sd.connect() + self.scon = await sd.connect_pg() + try: + await self._test_sys_query_stats() + finally: + await self.scon.close() + await self.con.aclose() + self.con, self.scon = old_cons diff --git a/tests/test_sql_query.py b/tests/test_sql_query.py index 8bf94846b72..bf799d40f5e 100644 --- a/tests/test_sql_query.py +++ b/tests/test_sql_query.py @@ -3015,6 +3015,57 @@ async def are_policies_applied() -> bool: # setting cleanup not needed, since with end with the None, None + async def test_sql_query_set_05(self): + # IntervalStyle + + await self.scon.execute('SET IntervalStyle TO ISO_8601;') + [[res]] = await self.squery_values( + "SELECT '2 years 15 months 100 weeks 99 hours'::interval::text;" + ) + self.assertEqual(res, 'P3Y3M700DT99H') + + await self.scon.execute('SET IntervalStyle TO postgres_verbose;') + [[res]] = await self.squery_values( + "SELECT '2 years 15 months 100 weeks 99 hours'::interval::text;" + ) + self.assertEqual(res, '@ 3 years 3 mons 700 days 99 hours') + + await self.scon.execute('SET IntervalStyle TO sql_standard;') + [[res]] = await self.squery_values( + "SELECT '2 years 15 months 100 weeks 99 hours'::interval::text;" + ) + self.assertEqual(res, '+3-3 +700 +99:00:00') + + async def test_sql_query_set_06(self): + # bytea_output + + await self.scon.execute('SET bytea_output TO hex') + [[res]] = await self.squery_values( + "SELECT '\\x01abcdef01'::bytea::text" + ) + self.assertEqual(res, r'\x01abcdef01') + + await self.scon.execute('SET bytea_output TO escape') + [[res]] = await self.squery_values( + "SELECT '\\x01abcdef01'::bytea::text" + ) + self.assertEqual(res, r'\001\253\315\357\001') + + async def test_sql_query_set_07(self): + # enable_memoize + + await self.scon.execute('SET enable_memoize TO ye') + [[res]] = await 
self.squery_values( + "SELECT 'hello'" + ) + self.assertEqual(res, 'hello') + + await self.scon.execute('SET enable_memoize TO off') + [[res]] = await self.squery_values( + "SELECT 'hello'" + ) + self.assertEqual(res, 'hello') + @test.skip( 'blocking the connection causes other tests which trigger a ' 'PostgreSQL error to encounter a InternalServerError and close '