Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional support of protocol 3.0 #552

Merged
merged 2 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions gel/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class QueryContext(typing.NamedTuple):
retry_options: typing.Optional[options.RetryOptions]
state: typing.Optional[options.State]
warning_handler: options.WarningHandler
annotations: typing.Dict[str, str]

def lower(
self, *, allow_capabilities: enums.Capability
Expand All @@ -83,6 +84,7 @@ def lower(
required_one=self.query_options.required_one,
allow_capabilities=allow_capabilities,
state=self.state.as_dict() if self.state else None,
annotations=self.annotations,
)


Expand All @@ -91,6 +93,7 @@ class ExecuteContext(typing.NamedTuple):
cache: QueryCache
state: typing.Optional[options.State]
warning_handler: options.WarningHandler
annotations: typing.Dict[str, str]

def lower(
self, *, allow_capabilities: enums.Capability
Expand All @@ -105,6 +108,7 @@ def lower(
output_format=protocol.OutputFormat.NONE,
allow_capabilities=allow_capabilities,
state=self.state.as_dict() if self.state else None,
annotations=self.annotations,
)


Expand Down Expand Up @@ -193,6 +197,9 @@ def _get_state(self) -> options.State:
def _get_warning_handler(self) -> options.WarningHandler:
...

def _get_annotations(self) -> typing.Dict[str, str]:
return {}


class ReadOnlyExecutor(BaseReadOnlyExecutor):
"""Subclasses can execute *at least* read-only queries"""
Expand All @@ -211,6 +218,7 @@ def query(self, query: str, *args, **kwargs) -> list:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

def query_single(
Expand All @@ -223,6 +231,7 @@ def query_single(
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

def query_required_single(self, query: str, *args, **kwargs) -> typing.Any:
Expand All @@ -233,6 +242,7 @@ def query_required_single(self, query: str, *args, **kwargs) -> typing.Any:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

def query_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -243,6 +253,7 @@ def query_json(self, query: str, *args, **kwargs) -> str:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

def query_single_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -253,6 +264,7 @@ def query_single_json(self, query: str, *args, **kwargs) -> str:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

def query_required_single_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -263,6 +275,7 @@ def query_required_single_json(self, query: str, *args, **kwargs) -> str:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

def query_sql(self, query: str, *args, **kwargs) -> typing.Any:
Expand All @@ -278,6 +291,7 @@ def query_sql(self, query: str, *args, **kwargs) -> typing.Any:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

@abc.abstractmethod
Expand All @@ -290,6 +304,7 @@ def execute(self, commands: str, *args, **kwargs) -> None:
cache=self._get_query_cache(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

def execute_sql(self, commands: str, *args, **kwargs) -> None:
Expand All @@ -303,6 +318,7 @@ def execute_sql(self, commands: str, *args, **kwargs) -> None:
cache=self._get_query_cache(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))


Expand All @@ -329,6 +345,7 @@ async def query(self, query: str, *args, **kwargs) -> list:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

async def query_single(self, query: str, *args, **kwargs) -> typing.Any:
Expand All @@ -339,6 +356,7 @@ async def query_single(self, query: str, *args, **kwargs) -> typing.Any:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

async def query_required_single(
Expand All @@ -354,6 +372,7 @@ async def query_required_single(
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

async def query_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -364,6 +383,7 @@ async def query_json(self, query: str, *args, **kwargs) -> str:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

async def query_single_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -374,6 +394,7 @@ async def query_single_json(self, query: str, *args, **kwargs) -> str:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

async def query_required_single_json(
Expand All @@ -389,6 +410,7 @@ async def query_required_single_json(
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

async def query_sql(self, query: str, *args, **kwargs) -> typing.Any:
Expand All @@ -404,6 +426,7 @@ async def query_sql(self, query: str, *args, **kwargs) -> typing.Any:
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

@abc.abstractmethod
Expand All @@ -416,6 +439,7 @@ async def execute(self, commands: str, *args, **kwargs) -> None:
cache=self._get_query_cache(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))

async def execute_sql(self, commands: str, *args, **kwargs) -> None:
Expand All @@ -429,6 +453,7 @@ async def execute_sql(self, commands: str, *args, **kwargs) -> None:
cache=self._get_query_cache(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))


Expand Down
3 changes: 3 additions & 0 deletions gel/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,9 @@ def _get_state(self) -> _options.State:
def _get_warning_handler(self) -> _options.WarningHandler:
return self._options.warning_handler

def _get_annotations(self) -> typing.Dict[str, str]:
return self._options.annotations

@property
def max_concurrency(self) -> int:
"""Max number of connections in the pool."""
Expand Down
36 changes: 35 additions & 1 deletion gel/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,13 +413,27 @@ def without_globals(self, *global_names):
)
return result

def with_annotation(self, name: str, value: str):
result = self._shallow_clone()
result._options = self._options.with_annotations(
self._options.annotations | {name: value}
)
return result

def without_annotation(self, name: str):
result = self._shallow_clone()
annotations = self._options.annotations.copy()
annotations.pop(name, None)
result._options = self._options.with_annotations(annotations)
return result


class _Options:
"""Internal class for storing connection options"""

__slots__ = [
'_retry_options', '_transaction_options', '_state',
'_warning_handler'
'_warning_handler', '_annotations'
]

def __init__(
Expand All @@ -428,11 +442,13 @@ def __init__(
transaction_options: TransactionOptions,
state: State,
warning_handler: WarningHandler,
annotations: typing.Dict[str, str],
):
self._retry_options = retry_options
self._transaction_options = transaction_options
self._state = state
self._warning_handler = warning_handler
self._annotations = annotations

@property
def retry_options(self):
Expand All @@ -450,12 +466,17 @@ def state(self):
def warning_handler(self):
return self._warning_handler

@property
def annotations(self):
return self._annotations

def with_retry_options(self, options: RetryOptions):
return _Options(
options,
self._transaction_options,
self._state,
self._warning_handler,
self._annotations,
)

def with_transaction_options(self, options: TransactionOptions):
Expand All @@ -464,6 +485,7 @@ def with_transaction_options(self, options: TransactionOptions):
options,
self._state,
self._warning_handler,
self._annotations,
)

def with_state(self, state: State):
Expand All @@ -472,6 +494,7 @@ def with_state(self, state: State):
self._transaction_options,
state,
self._warning_handler,
self._annotations,
)

def with_warning_handler(self, warning_handler: WarningHandler):
Expand All @@ -480,6 +503,16 @@ def with_warning_handler(self, warning_handler: WarningHandler):
self._transaction_options,
self._state,
warning_handler,
self._annotations,
)

def with_annotations(self, annotations: typing.Dict[str, str]):
return _Options(
self._retry_options,
self._transaction_options,
self._state,
self._warning_handler,
annotations,
)

@classmethod
Expand All @@ -489,4 +522,5 @@ def defaults(cls):
TransactionOptions.defaults(),
State.defaults(),
log_warnings,
{},
)
2 changes: 2 additions & 0 deletions gel/protocol/protocol.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ cdef class ExecuteContext:
bint inline_typeids
uint64_t allow_capabilities
object state
object annotations

# Contextual variables
readonly bytes cardinality
Expand Down Expand Up @@ -151,6 +152,7 @@ cdef class SansIOProtocol:
cdef inline ignore_headers(self)
cdef inline dict read_headers(self)
cdef dict parse_error_headers(self)
cdef write_annotations(self, ExecuteContext ctx, WriteBuffer buf)

cdef parse_error_message(self)

Expand Down
29 changes: 24 additions & 5 deletions gel/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ cdef class ExecuteContext:
inline_typeids: bool = False,
allow_capabilities: enums.Capability = enums.Capability.ALL,
state: typing.Optional[dict] = None,
annotations: typing.Optional[dict[str, str]] = None,
):
self.query = query
self.args = args
Expand All @@ -129,6 +130,7 @@ cdef class ExecuteContext:
self.in_dc = self.out_dc = None
self.capabilities = 0
self.warnings = ()
self.annotations = annotations

cdef inline bint has_na_cardinality(self):
return self.cardinality == CARDINALITY_NOT_APPLICABLE
Expand Down Expand Up @@ -250,6 +252,18 @@ cdef class SansIOProtocol:

return headers

cdef write_annotations(self, ExecuteContext ctx, WriteBuffer buf):
num_annos = len(ctx.annotations) if ctx.annotations is not None else 0
if self.protocol_version >= (3, 0) and num_annos > 0:
if num_annos >= 1 << 16:
raise errors.InvalidArgumentError("too many annotations")
buf.write_int16(num_annos)
for key, value in ctx.annotations.items():
buf.write_len_prefixed_utf8(key)
buf.write_len_prefixed_utf8(value)
else:
buf.write_int16(0) # no annotations

cdef ensure_connected(self):
if self.cancelled:
raise errors.ClientConnectionClosedError(
Expand Down Expand Up @@ -297,7 +311,7 @@ cdef class SansIOProtocol:
raise RuntimeError('not connected')

buf = WriteBuffer.new_message(PREPARE_MSG)
buf.write_int16(0) # no headers
self.write_annotations(ctx, buf)

params = self.encode_parse_params(ctx)

Expand Down Expand Up @@ -359,7 +373,7 @@ cdef class SansIOProtocol:
params = self.encode_parse_params(ctx)

buf = WriteBuffer.new_message(EXECUTE_MSG)
buf.write_int16(0) # no headers
self.write_annotations(ctx, buf)

buf.write_buffer(params)

Expand Down Expand Up @@ -525,8 +539,13 @@ cdef class SansIOProtocol:
self.reset_status()

buf = WriteBuffer.new_message(DUMP_MSG)
buf.write_int16(0) # no headers
buf.end_message()
if self.protocol_version >= (3, 0):
buf.write_int16(0) # no annotations
buf.write_int64(0) # flags
buf.end_message()
else:
buf.write_int16(0) # no headers
buf.end_message()
buf.write_bytes(SYNC_MESSAGE)
self.write(buf)

Expand Down Expand Up @@ -627,7 +646,7 @@ cdef class SansIOProtocol:
self.reset_status()

buf = WriteBuffer.new_message(RESTORE_MSG)
buf.write_int16(0) # no headers
buf.write_int16(0) # no attributes
buf.write_int16(1) # -j level
buf.write_bytes(header)
buf.end_message()
Expand Down
Loading
Loading