Skip to content

Commit

Permalink
Additional support of protocol 3.0 (#552)
Browse files Browse the repository at this point in the history
* Add with/out_annotation()
* Use new Dump message for protocol >= 3.0
  • Loading branch information
fantix authored Nov 27, 2024
1 parent 2757156 commit a64a177
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 6 deletions.
25 changes: 25 additions & 0 deletions gel/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class QueryContext(typing.NamedTuple):
retry_options: typing.Optional[options.RetryOptions]
state: typing.Optional[options.State]
warning_handler: options.WarningHandler
annotations: typing.Dict[str, str]

def lower(
self, *, allow_capabilities: enums.Capability
Expand All @@ -83,6 +84,7 @@ def lower(
required_one=self.query_options.required_one,
allow_capabilities=allow_capabilities,
state=self.state.as_dict() if self.state else None,
annotations=self.annotations,
)


Expand All @@ -91,6 +93,7 @@ class ExecuteContext(typing.NamedTuple):
cache: QueryCache
state: typing.Optional[options.State]
warning_handler: options.WarningHandler
annotations: typing.Dict[str, str]

def lower(
self, *, allow_capabilities: enums.Capability
Expand All @@ -105,6 +108,7 @@ def lower(
output_format=protocol.OutputFormat.NONE,
allow_capabilities=allow_capabilities,
state=self.state.as_dict() if self.state else None,
annotations=self.annotations,
)


Expand Down Expand Up @@ -193,6 +197,9 @@ def _get_state(self) -> options.State:
def _get_warning_handler(self) -> options.WarningHandler:
...

def _get_annotations(self) -> typing.Dict[str, str]:
return {}


class ReadOnlyExecutor(BaseReadOnlyExecutor):
"""Subclasses can execute *at least* read-only queries"""
Expand All @@ -211,6 +218,7 @@ def query(self, query: str, *args, **kwargs) -> list:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

def query_single(
Expand All @@ -223,6 +231,7 @@ def query_single(
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

def query_required_single(self, query: str, *args, **kwargs) -> typing.Any:
Expand All @@ -233,6 +242,7 @@ def query_required_single(self, query: str, *args, **kwargs) -> typing.Any:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

def query_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -243,6 +253,7 @@ def query_json(self, query: str, *args, **kwargs) -> str:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

def query_single_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -253,6 +264,7 @@ def query_single_json(self, query: str, *args, **kwargs) -> str:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

def query_required_single_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -263,6 +275,7 @@ def query_required_single_json(self, query: str, *args, **kwargs) -> str:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

def query_sql(self, query: str, *args, **kwargs) -> typing.Any:
Expand All @@ -278,6 +291,7 @@ def query_sql(self, query: str, *args, **kwargs) -> typing.Any:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

@abc.abstractmethod
Expand All @@ -290,6 +304,7 @@ def execute(self, commands: str, *args, **kwargs) -> None:
cache=self._get_query_cache(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

def execute_sql(self, commands: str, *args, **kwargs) -> None:
Expand All @@ -303,6 +318,7 @@ def execute_sql(self, commands: str, *args, **kwargs) -> None:
cache=self._get_query_cache(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))


Expand All @@ -329,6 +345,7 @@ async def query(self, query: str, *args, **kwargs) -> list:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

async def query_single(self, query: str, *args, **kwargs) -> typing.Any:
Expand All @@ -339,6 +356,7 @@ async def query_single(self, query: str, *args, **kwargs) -> typing.Any:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

async def query_required_single(
Expand All @@ -354,6 +372,7 @@ async def query_required_single(
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

async def query_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -364,6 +383,7 @@ async def query_json(self, query: str, *args, **kwargs) -> str:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

async def query_single_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -374,6 +394,7 @@ async def query_single_json(self, query: str, *args, **kwargs) -> str:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

async def query_required_single_json(
Expand All @@ -389,6 +410,7 @@ async def query_required_single_json(
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

async def query_sql(self, query: str, *args, **kwargs) -> typing.Any:
Expand All @@ -404,6 +426,7 @@ async def query_sql(self, query: str, *args, **kwargs) -> typing.Any:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

@abc.abstractmethod
Expand All @@ -416,6 +439,7 @@ async def execute(self, commands: str, *args, **kwargs) -> None:
cache=self._get_query_cache(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

async def execute_sql(self, commands: str, *args, **kwargs) -> None:
Expand All @@ -429,6 +453,7 @@ async def execute_sql(self, commands: str, *args, **kwargs) -> None:
cache=self._get_query_cache(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))


Expand Down
3 changes: 3 additions & 0 deletions gel/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,9 @@ def _get_state(self) -> _options.State:
def _get_warning_handler(self) -> _options.WarningHandler:
return self._options.warning_handler

def _get_annotations(self) -> typing.Dict[str, str]:
return self._options.annotations

@property
def max_concurrency(self) -> int:
"""Max number of connections in the pool."""
Expand Down
36 changes: 35 additions & 1 deletion gel/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,13 +413,27 @@ def without_globals(self, *global_names):
)
return result

def with_annotation(self, name: str, value: str):
result = self._shallow_clone()
result._options = self._options.with_annotations(
self._options.annotations | {name: value}
)
return result

def without_annotation(self, name: str):
result = self._shallow_clone()
annotations = self._options.annotations.copy()
annotations.pop(name, None)
result._options = self._options.with_annotations(annotations)
return result


class _Options:
"""Internal class for storing connection options"""

__slots__ = [
'_retry_options', '_transaction_options', '_state',
'_warning_handler'
'_warning_handler', '_annotations'
]

def __init__(
Expand All @@ -428,11 +442,13 @@ def __init__(
transaction_options: TransactionOptions,
state: State,
warning_handler: WarningHandler,
annotations: typing.Dict[str, str],
):
self._retry_options = retry_options
self._transaction_options = transaction_options
self._state = state
self._warning_handler = warning_handler
self._annotations = annotations

@property
def retry_options(self):
Expand All @@ -450,12 +466,17 @@ def state(self):
def warning_handler(self):
return self._warning_handler

@property
def annotations(self):
return self._annotations

def with_retry_options(self, options: RetryOptions):
return _Options(
options,
self._transaction_options,
self._state,
self._warning_handler,
self._annotations,
)

def with_transaction_options(self, options: TransactionOptions):
Expand All @@ -464,6 +485,7 @@ def with_transaction_options(self, options: TransactionOptions):
options,
self._state,
self._warning_handler,
self._annotations,
)

def with_state(self, state: State):
Expand All @@ -472,6 +494,7 @@ def with_state(self, state: State):
self._transaction_options,
state,
self._warning_handler,
self._annotations,
)

def with_warning_handler(self, warning_handler: WarningHandler):
Expand All @@ -480,6 +503,16 @@ def with_warning_handler(self, warning_handler: WarningHandler):
self._transaction_options,
self._state,
warning_handler,
self._annotations,
)

def with_annotations(self, annotations: typing.Dict[str, str]):
return _Options(
self._retry_options,
self._transaction_options,
self._state,
self._warning_handler,
annotations,
)

@classmethod
Expand All @@ -489,4 +522,5 @@ def defaults(cls):
TransactionOptions.defaults(),
State.defaults(),
log_warnings,
{},
)
2 changes: 2 additions & 0 deletions gel/protocol/protocol.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ cdef class ExecuteContext:
bint inline_typeids
uint64_t allow_capabilities
object state
object annotations

# Contextual variables
readonly bytes cardinality
Expand Down Expand Up @@ -151,6 +152,7 @@ cdef class SansIOProtocol:
cdef inline ignore_headers(self)
cdef inline dict read_headers(self)
cdef dict parse_error_headers(self)
cdef write_annotations(self, ExecuteContext ctx, WriteBuffer buf)

cdef parse_error_message(self)

Expand Down
29 changes: 24 additions & 5 deletions gel/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ cdef class ExecuteContext:
inline_typeids: bool = False,
allow_capabilities: enums.Capability = enums.Capability.ALL,
state: typing.Optional[dict] = None,
annotations: typing.Optional[dict[str, str]] = None,
):
self.query = query
self.args = args
Expand All @@ -129,6 +130,7 @@ cdef class ExecuteContext:
self.in_dc = self.out_dc = None
self.capabilities = 0
self.warnings = ()
self.annotations = annotations

cdef inline bint has_na_cardinality(self):
return self.cardinality == CARDINALITY_NOT_APPLICABLE
Expand Down Expand Up @@ -250,6 +252,18 @@ cdef class SansIOProtocol:

return headers

cdef write_annotations(self, ExecuteContext ctx, WriteBuffer buf):
num_annos = len(ctx.annotations) if ctx.annotations is not None else 0
if self.protocol_version >= (3, 0) and num_annos > 0:
if num_annos >= 1 << 16:
raise errors.InvalidArgumentError("too many annotations")
buf.write_int16(num_annos)
for key, value in ctx.annotations.items():
buf.write_len_prefixed_utf8(key)
buf.write_len_prefixed_utf8(value)
else:
buf.write_int16(0) # no annotations

cdef ensure_connected(self):
if self.cancelled:
raise errors.ClientConnectionClosedError(
Expand Down Expand Up @@ -297,7 +311,7 @@ cdef class SansIOProtocol:
raise RuntimeError('not connected')

buf = WriteBuffer.new_message(PREPARE_MSG)
buf.write_int16(0) # no headers
self.write_annotations(ctx, buf)

params = self.encode_parse_params(ctx)

Expand Down Expand Up @@ -359,7 +373,7 @@ cdef class SansIOProtocol:
params = self.encode_parse_params(ctx)

buf = WriteBuffer.new_message(EXECUTE_MSG)
buf.write_int16(0) # no headers
self.write_annotations(ctx, buf)

buf.write_buffer(params)

Expand Down Expand Up @@ -525,8 +539,13 @@ cdef class SansIOProtocol:
self.reset_status()

buf = WriteBuffer.new_message(DUMP_MSG)
buf.write_int16(0) # no headers
buf.end_message()
if self.protocol_version >= (3, 0):
buf.write_int16(0) # no annotations
buf.write_int64(0) # flags
buf.end_message()
else:
buf.write_int16(0) # no headers
buf.end_message()
buf.write_bytes(SYNC_MESSAGE)
self.write(buf)

Expand Down Expand Up @@ -627,7 +646,7 @@ cdef class SansIOProtocol:
self.reset_status()

buf = WriteBuffer.new_message(RESTORE_MSG)
buf.write_int16(0) # no headers
buf.write_int16(0) # no attributes
buf.write_int16(1) # -j level
buf.write_bytes(header)
buf.end_message()
Expand Down
Loading

0 comments on commit a64a177

Please sign in to comment.