From 31fd3ca3cca8b7618d7549f4b44145657189eb39 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Wed, 27 Nov 2024 11:21:22 -0500 Subject: [PATCH 1/2] Add with/out_annotation() --- gel/abstract.py | 25 +++++++++++++++++++++++++ gel/base_client.py | 3 +++ gel/options.py | 36 +++++++++++++++++++++++++++++++++++- gel/protocol/protocol.pxd | 2 ++ gel/protocol/protocol.pyx | 18 ++++++++++++++++-- gel/transaction.py | 6 ++++++ tests/test_async_query.py | 1 + tests/test_sync_query.py | 1 + 8 files changed, 89 insertions(+), 3 deletions(-) diff --git a/gel/abstract.py b/gel/abstract.py index 32aa0e1f..abffbd23 100644 --- a/gel/abstract.py +++ b/gel/abstract.py @@ -67,6 +67,7 @@ class QueryContext(typing.NamedTuple): retry_options: typing.Optional[options.RetryOptions] state: typing.Optional[options.State] warning_handler: options.WarningHandler + annotations: typing.Dict[str, str] def lower( self, *, allow_capabilities: enums.Capability @@ -83,6 +84,7 @@ def lower( required_one=self.query_options.required_one, allow_capabilities=allow_capabilities, state=self.state.as_dict() if self.state else None, + annotations=self.annotations, ) @@ -91,6 +93,7 @@ class ExecuteContext(typing.NamedTuple): cache: QueryCache state: typing.Optional[options.State] warning_handler: options.WarningHandler + annotations: typing.Dict[str, str] def lower( self, *, allow_capabilities: enums.Capability @@ -105,6 +108,7 @@ def lower( output_format=protocol.OutputFormat.NONE, allow_capabilities=allow_capabilities, state=self.state.as_dict() if self.state else None, + annotations=self.annotations, ) @@ -193,6 +197,9 @@ def _get_state(self) -> options.State: def _get_warning_handler(self) -> options.WarningHandler: ... + def _get_annotations(self) -> typing.Dict[str, str]: + return {} + class ReadOnlyExecutor(BaseReadOnlyExecutor): """Subclasses can execute *at least* read-only queries""" @@ -211,6 +218,7 @@ def query(self, query: str, *args, **kwargs) -> list: retry_options=self._get_retry_options(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) def query_single( @@ -223,6 +231,7 @@ def query_single( retry_options=self._get_retry_options(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) def query_required_single(self, query: str, *args, **kwargs) -> typing.Any: @@ -233,6 +242,7 @@ def query_required_single(self, query: str, *args, **kwargs) -> typing.Any: retry_options=self._get_retry_options(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) def query_json(self, query: str, *args, **kwargs) -> str: @@ -243,6 +253,7 @@ def query_json(self, query: str, *args, **kwargs) -> str: retry_options=self._get_retry_options(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) def query_single_json(self, query: str, *args, **kwargs) -> str: @@ -253,6 +264,7 @@ def query_single_json(self, query: str, *args, **kwargs) -> str: retry_options=self._get_retry_options(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) def query_required_single_json(self, query: str, *args, **kwargs) -> str: @@ -263,6 +275,7 @@ def query_required_single_json(self, query: str, *args, **kwargs) -> str: retry_options=self._get_retry_options(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) def query_sql(self, query: str, *args, **kwargs) -> typing.Any: @@ -278,6 +291,7 @@ def query_sql(self, query: str, *args, **kwargs) -> typing.Any: retry_options=self._get_retry_options(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) @abc.abstractmethod @@ -290,6 +304,7 @@ def execute(self, commands: str, *args, **kwargs) -> None: cache=self._get_query_cache(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) def execute_sql(self, commands: str, *args, **kwargs) -> None: @@ -303,6 +318,7 @@ def execute_sql(self, commands: str, *args, **kwargs) -> None: cache=self._get_query_cache(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) @@ -329,6 +345,7 @@ async def query(self, query: str, *args, **kwargs) -> list: retry_options=self._get_retry_options(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) async def query_single(self, query: str, *args, **kwargs) -> typing.Any: @@ -339,6 +356,7 @@ async def query_single(self, query: str, *args, **kwargs) -> typing.Any: retry_options=self._get_retry_options(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) async def query_required_single( @@ -354,6 +372,7 @@ async def query_required_single( retry_options=self._get_retry_options(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) async def query_json(self, query: str, *args, **kwargs) -> str: @@ -364,6 +383,7 @@ async def query_json(self, query: str, *args, **kwargs) -> str: retry_options=self._get_retry_options(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) async def query_single_json(self, query: str, *args, **kwargs) -> str: @@ -374,6 +394,7 @@ async def query_single_json(self, query: str, *args, **kwargs) -> str: retry_options=self._get_retry_options(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) async def query_required_single_json( @@ -389,6 +410,7 @@ async def query_required_single_json( retry_options=self._get_retry_options(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) async def query_sql(self, query: str, *args, **kwargs) -> typing.Any: @@ -404,6 +426,7 @@ async def query_sql(self, query: str, *args, **kwargs) -> typing.Any: retry_options=self._get_retry_options(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) @abc.abstractmethod @@ -416,6 +439,7 @@ async def execute(self, commands: str, *args, **kwargs) -> None: cache=self._get_query_cache(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) async def execute_sql(self, commands: str, *args, **kwargs) -> None: @@ -429,6 +453,7 @@ async def execute_sql(self, commands: str, *args, **kwargs) -> None: cache=self._get_query_cache(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) diff --git a/gel/base_client.py b/gel/base_client.py index 5bc1a86a..83caefac 100644 --- a/gel/base_client.py +++ b/gel/base_client.py @@ -694,6 +694,9 @@ def _get_state(self) -> _options.State: def _get_warning_handler(self) -> _options.WarningHandler: return self._options.warning_handler + def _get_annotations(self) -> typing.Dict[str, str]: + return self._options.annotations + @property def max_concurrency(self) -> int: """Max number of connections in the pool.""" diff --git a/gel/options.py b/gel/options.py index 6f83d421..239cb4c1 100644 --- a/gel/options.py +++ b/gel/options.py @@ -413,13 +413,27 @@ def without_globals(self, *global_names): ) return result + def with_annotation(self, name: str, value: str): + result = self._shallow_clone() + result._options = self._options.with_annotations( + self._options.annotations | {name: value} + ) + return result + + def without_annotation(self, name: str): + result = self._shallow_clone() + annotations = self._options.annotations.copy() + annotations.pop(name, None) + result._options = self._options.with_annotations(annotations) + return result + class _Options: """Internal class for storing connection options""" __slots__ = [ '_retry_options', '_transaction_options', '_state', - '_warning_handler' + '_warning_handler', '_annotations' ] def __init__( @@ -428,11 +442,13 @@ def __init__( transaction_options: TransactionOptions, state: State, warning_handler: WarningHandler, + annotations: typing.Dict[str, str], ): self._retry_options = retry_options self._transaction_options = transaction_options self._state = state self._warning_handler = warning_handler + self._annotations = annotations @property def retry_options(self): @@ -450,12 +466,17 @@ def state(self): def warning_handler(self): return self._warning_handler + @property + def annotations(self): + return self._annotations + def with_retry_options(self, options: RetryOptions): return _Options( options, self._transaction_options, self._state, self._warning_handler, + self._annotations, ) def with_transaction_options(self, options: TransactionOptions): @@ -464,6 +485,7 @@ def with_transaction_options(self, options: TransactionOptions): options, self._state, self._warning_handler, + self._annotations, ) def with_state(self, state: State): @@ -472,6 +494,7 @@ def with_state(self, state: State): self._transaction_options, state, self._warning_handler, + self._annotations, ) def with_warning_handler(self, warning_handler: WarningHandler): @@ -480,6 +503,16 @@ def with_warning_handler(self, warning_handler: WarningHandler): self._transaction_options, self._state, warning_handler, + self._annotations, + ) + + def with_annotations(self, annotations: typing.Dict[str, str]): + return _Options( + self._retry_options, + self._transaction_options, + self._state, + self._warning_handler, + annotations, ) @classmethod @@ -489,4 +522,5 @@ def defaults(cls): TransactionOptions.defaults(), State.defaults(), log_warnings, + {}, ) diff --git a/gel/protocol/protocol.pxd b/gel/protocol/protocol.pxd index 7c5f943f..97eb4ffc 100644 --- a/gel/protocol/protocol.pxd +++ b/gel/protocol/protocol.pxd @@ -89,6 +89,7 @@ cdef class ExecuteContext: bint inline_typeids uint64_t allow_capabilities object state + object annotations # Contextual variables readonly bytes cardinality @@ -151,6 +152,7 @@ cdef class SansIOProtocol: cdef inline ignore_headers(self) cdef inline dict read_headers(self) cdef dict parse_error_headers(self) + cdef write_annotations(self, ExecuteContext ctx, WriteBuffer buf) cdef parse_error_message(self) diff --git a/gel/protocol/protocol.pyx b/gel/protocol/protocol.pyx index e22efef7..47936cf6 100644 --- a/gel/protocol/protocol.pyx +++ b/gel/protocol/protocol.pyx @@ -109,6 +109,7 @@ cdef class ExecuteContext: inline_typeids: bool = False, allow_capabilities: enums.Capability = enums.Capability.ALL, state: typing.Optional[dict] = None, + annotations: typing.Optional[dict[str, str]] = None, ): self.query = query self.args = args @@ -129,6 +130,7 @@ cdef class ExecuteContext: self.in_dc = self.out_dc = None self.capabilities = 0 self.warnings = () + self.annotations = annotations cdef inline bint has_na_cardinality(self): return self.cardinality == CARDINALITY_NOT_APPLICABLE @@ -250,6 +252,18 @@ cdef class SansIOProtocol: return headers + cdef write_annotations(self, ExecuteContext ctx, WriteBuffer buf): + num_annos = len(ctx.annotations) if ctx.annotations is not None else 0 + if self.protocol_version >= (3, 0) and num_annos > 0: + if num_annos >= 1 << 16: + raise errors.InvalidArgumentError("too many annotations") + buf.write_int16(num_annos) + for key, value in ctx.annotations.items(): + buf.write_len_prefixed_utf8(key) + buf.write_len_prefixed_utf8(value) + else: + buf.write_int16(0) # no annotations + cdef ensure_connected(self): if self.cancelled: raise errors.ClientConnectionClosedError( @@ -297,7 +311,7 @@ cdef class SansIOProtocol: raise RuntimeError('not connected') buf = WriteBuffer.new_message(PREPARE_MSG) - buf.write_int16(0) # no headers + self.write_annotations(ctx, buf) params = self.encode_parse_params(ctx) @@ -359,7 +373,7 @@ cdef class SansIOProtocol: params = self.encode_parse_params(ctx) buf = WriteBuffer.new_message(EXECUTE_MSG) - buf.write_int16(0) # no headers + self.write_annotations(ctx, buf) buf.write_buffer(params) diff --git a/gel/transaction.py b/gel/transaction.py index 8f3ab29f..2e37d3ac 100644 --- a/gel/transaction.py +++ b/gel/transaction.py @@ -17,6 +17,8 @@ # +import typing + import enum from . import abstract @@ -188,6 +190,9 @@ def _get_state(self) -> options.State: def _get_warning_handler(self) -> options.WarningHandler: return self._client._get_warning_handler() + def _get_annotations(self) -> typing.Dict[str, str]: + return self._client._get_annotations() + async def _query(self, query_context: abstract.QueryContext): await self._ensure_transaction() return await self._connection.raw_query(query_context) @@ -202,6 +207,7 @@ async def _privileged_execute(self, query: str) -> None: cache=self._get_query_cache(), state=self._get_state(), warning_handler=self._get_warning_handler(), + annotations=self._get_annotations(), )) diff --git a/tests/test_async_query.py b/tests/test_async_query.py index 851afd02..602346ad 100644 --- a/tests/test_async_query.py +++ b/tests/test_async_query.py @@ -1036,6 +1036,7 @@ async def test_json_elements(self): retry_options=None, state=None, warning_handler=lambda _ex, _: None, + annotations={}, ) ) self.assertEqual( diff --git a/tests/test_sync_query.py b/tests/test_sync_query.py index f6d27f60..1a098e44 100644 --- a/tests/test_sync_query.py +++ b/tests/test_sync_query.py @@ -793,6 +793,7 @@ def test_json_elements(self): retry_options=None, state=None, warning_handler=lambda _ex, _: None, + annotations={}, ) ) ) From df0002868560c8a91b7bd83cb83e84f21dadd4dd Mon Sep 17 00:00:00 2001 From: Fantix King Date: Wed, 27 Nov 2024 11:25:24 -0500 Subject: [PATCH 2/2] Use new Dump message for protocol >= 3.0 --- gel/protocol/protocol.pyx | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/gel/protocol/protocol.pyx b/gel/protocol/protocol.pyx index 47936cf6..7b9b67a4 100644 --- a/gel/protocol/protocol.pyx +++ b/gel/protocol/protocol.pyx @@ -539,8 +539,13 @@ cdef class SansIOProtocol: self.reset_status() buf = WriteBuffer.new_message(DUMP_MSG) - buf.write_int16(0) # no headers - buf.end_message() + if self.protocol_version >= (3, 0): + buf.write_int16(0) # no annotations + buf.write_int64(0) # flags + buf.end_message() + else: + buf.write_int16(0) # no headers + buf.end_message() buf.write_bytes(SYNC_MESSAGE) self.write(buf) @@ -641,7 +646,7 @@ cdef class SansIOProtocol: self.reset_status() buf = WriteBuffer.new_message(RESTORE_MSG) - buf.write_int16(0) # no headers + buf.write_int16(0) # no attributes buf.write_int16(1) # -j level buf.write_bytes(header) buf.end_message()