diff --git a/edgedb/_testbase.py b/edgedb/_testbase.py index 5680036c..e5b59508 100644 --- a/edgedb/_testbase.py +++ b/edgedb/_testbase.py @@ -372,10 +372,16 @@ def make_test_client( database='edgedb', user='edgedb', password='test', + host=..., + port=..., connection_class=..., ): conargs = cls.get_connect_args( cluster=cluster, database=database, user=user, password=password) + if host is not ...: + conargs['host'] = host + if port is not ...: + conargs['port'] = port if connection_class is ...: connection_class = ( asyncio_client.AsyncIOConnection diff --git a/edgedb/abstract.py b/edgedb/abstract.py index 7e8f4f6a..860275c7 100644 --- a/edgedb/abstract.py +++ b/edgedb/abstract.py @@ -49,7 +49,7 @@ class QueryWithArgs(typing.NamedTuple): class QueryCache(typing.NamedTuple): codecs_registry: protocol.CodecsRegistry - query_cache: protocol.QueryCodecsCache + query_cache: protocol.LRUMapping class QueryOptions(typing.NamedTuple): @@ -65,12 +65,42 @@ class QueryContext(typing.NamedTuple): retry_options: typing.Optional[options.RetryOptions] state: typing.Optional[options.State] + def lower( + self, *, allow_capabilities: enums.Capability + ) -> protocol.ExecuteContext: + return protocol.ExecuteContext( + query=self.query.query, + args=self.query.args, + kwargs=self.query.kwargs, + reg=self.cache.codecs_registry, + qc=self.cache.query_cache, + output_format=self.query_options.output_format, + expect_one=self.query_options.expect_one, + required_one=self.query_options.required_one, + allow_capabilities=allow_capabilities, + state=self.state.as_dict() if self.state else None, + ) + class ExecuteContext(typing.NamedTuple): query: QueryWithArgs cache: QueryCache state: typing.Optional[options.State] + def lower( + self, *, allow_capabilities: enums.Capability + ) -> protocol.ExecuteContext: + return protocol.ExecuteContext( + query=self.query.query, + args=self.query.args, + kwargs=self.query.kwargs, + reg=self.cache.codecs_registry, + qc=self.cache.query_cache, + output_format=protocol.OutputFormat.NONE, + allow_capabilities=allow_capabilities, + state=self.state.as_dict() if self.state else None, + ) + @dataclasses.dataclass class DescribeContext: diff --git a/edgedb/base_client.py b/edgedb/base_client.py index d8c33767..0744211c 100644 --- a/edgedb/base_client.py +++ b/edgedb/base_client.py @@ -31,6 +31,7 @@ BaseConnection_T = typing.TypeVar('BaseConnection_T', bound='BaseConnection') +QUERY_CACHE_SIZE = 1000 class BaseConnection(metaclass=abc.ABCMeta): @@ -183,17 +184,7 @@ async def privileged_execute( ) else: await self._protocol.execute( - query=execute_context.query.query, - args=execute_context.query.args, - kwargs=execute_context.query.kwargs, - reg=execute_context.cache.codecs_registry, - qc=execute_context.cache.query_cache, - output_format=protocol.OutputFormat.NONE, - allow_capabilities=enums.Capability.ALL, - state=( - execute_context.state.as_dict() - if execute_context.state else None - ), + execute_context.lower(allow_capabilities=enums.Capability.ALL) ) def is_in_transaction(self) -> bool: @@ -211,56 +202,31 @@ async def raw_query(self, query_context: abstract.QueryContext): await self.connect() reconnect = False - capabilities = None i = 0 - args = dict( - query=query_context.query.query, - args=query_context.query.args, - kwargs=query_context.query.kwargs, - reg=query_context.cache.codecs_registry, - qc=query_context.cache.query_cache, - output_format=query_context.query_options.output_format, - expect_one=query_context.query_options.expect_one, - required_one=query_context.query_options.required_one, - ) if self._protocol.is_legacy: - args["allow_capabilities"] = enums.Capability.LEGACY_EXECUTE + allow_capabilities = enums.Capability.LEGACY_EXECUTE else: - args["allow_capabilities"] = enums.Capability.EXECUTE - if query_context.state is not None: - args["state"] = query_context.state.as_dict() + allow_capabilities = enums.Capability.EXECUTE + ctx = query_context.lower(allow_capabilities=allow_capabilities) while True: i += 1 try: if reconnect: await self.connect(single_attempt=True) if self._protocol.is_legacy: - return await self._protocol.legacy_execute_anonymous( - **args - ) + return await self._protocol.legacy_execute_anonymous(ctx) else: - return await self._protocol.query(**args) + return await self._protocol.query(ctx) except errors.EdgeDBError as e: if query_context.retry_options is None: raise if not e.has_tag(errors.SHOULD_RETRY): raise e - if capabilities is None: - cache_item = query_context.cache.query_cache.get( - query_context.query.query, - query_context.query_options.output_format, - implicit_limit=0, - inline_typenames=False, - inline_typeids=False, - expect_one=query_context.query_options.expect_one, - ) - if cache_item is not None: - _, _, _, capabilities = cache_item # A query is read-only if it has no capabilities i.e. # capabilities == 0. Read-only queries are safe to retry. # Explicit transaction conflicts as well. if ( - capabilities != 0 + ctx.capabilities != 0 and not isinstance(e, errors.TransactionConflictError) ): raise e @@ -281,17 +247,9 @@ async def _execute(self, execute_context: abstract.ExecuteContext) -> None: ) else: await self._protocol.execute( - query=execute_context.query.query, - args=execute_context.query.args, - kwargs=execute_context.query.kwargs, - reg=execute_context.cache.codecs_registry, - qc=execute_context.cache.query_cache, - output_format=protocol.OutputFormat.NONE, - allow_capabilities=enums.Capability.EXECUTE, - state=( - execute_context.state.as_dict() - if execute_context.state else None - ), + execute_context.lower( + allow_capabilities=enums.Capability.EXECUTE + ) ) async def describe( @@ -473,7 +431,7 @@ def __init__( self._connection_factory = connection_factory self._connect_args = connect_args self._codecs_registry = protocol.CodecsRegistry() - self._query_cache = protocol.QueryCodecsCache() + self._query_cache = protocol.LRUMapping(maxsize=QUERY_CACHE_SIZE) if max_concurrency is not None and max_concurrency <= 0: raise ValueError( @@ -570,7 +528,7 @@ def set_connect_args(self, dsn=None, **connect_kwargs): connect_kwargs["dsn"] = dsn self._connect_args = connect_kwargs self._codecs_registry = protocol.CodecsRegistry() - self._query_cache = protocol.QueryCodecsCache() + self._query_cache = protocol.LRUMapping(maxsize=QUERY_CACHE_SIZE) self._working_addr = None self._working_config = None self._working_params = None diff --git a/edgedb/protocol/protocol.pxd b/edgedb/protocol/protocol.pxd index 03f04964..402c988c 100644 --- a/edgedb/protocol/protocol.pxd +++ b/edgedb/protocol/protocol.pxd @@ -67,15 +67,32 @@ cdef enum AuthenticationStatuses: AUTH_SASL_FINAL = 12 -cdef class QueryCodecsCache: - +cdef class ExecuteContext: cdef: - LRUMapping queries - - cdef set(self, str query, OutputFormat output_format, - int implicit_limit, bint inline_typenames, bint inline_typeids, - bint expect_one, bint has_na_cardinality, - BaseCodec in_type, BaseCodec out_type, int capabilities) + # Input arguments + str query + object args + object kwargs + CodecsRegistry reg + LRUMapping qc + OutputFormat output_format + bint expect_one + bint required_one + int implicit_limit + bint inline_typenames + bint inline_typeids + uint64_t allow_capabilities + object state + + # Contextual variables + bytes cardinality + BaseCodec in_dc + BaseCodec out_dc + readonly uint64_t capabilities + + cdef inline bint has_na_cardinality(self) + cdef bint load_from_cache(self) + cdef inline store_to_cache(self) cdef class SansIOProtocol: @@ -113,7 +130,7 @@ cdef class SansIOProtocol: cdef parse_data_messages(self, BaseCodec out_dc, result) cdef parse_sync_message(self) cdef parse_command_complete_message(self) - cdef parse_describe_type_message(self, CodecsRegistry reg) + cdef parse_describe_type_message(self, ExecuteContext ctx) cdef parse_describe_state_message(self) cdef parse_type_data(self, CodecsRegistry reg) cdef _amend_parse_error( @@ -141,17 +158,7 @@ cdef class SansIOProtocol: cdef ensure_connected(self) - cdef WriteBuffer encode_parse_params( - self, - str query, - object output_format, - bint expect_one, - int implicit_limit, - bint inline_typenames, - bint inline_typeids, - uint64_t allow_capabilities, - object state, - ) + cdef WriteBuffer encode_parse_params(self, ExecuteContext ctx) include "protocol_v0.pxd" diff --git a/edgedb/protocol/protocol.pyx b/edgedb/protocol/protocol.pyx index 469d3cb4..3e0643d3 100644 --- a/edgedb/protocol/protocol.pyx +++ b/edgedb/protocol/protocol.pyx @@ -91,44 +91,74 @@ cdef dict OLD_ERROR_CODES = { } -cdef class QueryCodecsCache: - - def __init__(self, *, cache_size=1000): - self.queries = LRUMapping(maxsize=cache_size) - - def get( - self, str query, OutputFormat output_format, - int implicit_limit, bint inline_typenames, bint inline_typeids, - bint expect_one +cdef class ExecuteContext: + def __init__( + self, + *, + query: str, + args, + kwargs, + reg: CodecsRegistry, + qc: LRUMapping, + output_format: OutputFormat, + expect_one: bool = False, + required_one: bool = False, + implicit_limit: int = 0, + inline_typenames: bool = False, + inline_typeids: bool = False, + allow_capabilities: enums.Capability = enums.Capability.ALL, + state: typing.Optional[dict] = None, ): + self.query = query + self.args = args + self.kwargs = kwargs + self.reg = reg + self.qc = qc + self.output_format = output_format + self.expect_one = bool(expect_one) + self.required_one = bool(required_one) + self.implicit_limit = implicit_limit + self.inline_typenames = bool(inline_typenames) + self.inline_typeids = bool(inline_typeids) + self.allow_capabilities = allow_capabilities + self.state = state + + self.cardinality = None + self.in_dc = self.out_dc = None + self.capabilities = 0 + + cdef inline bint has_na_cardinality(self): + return self.cardinality == CARDINALITY_NOT_APPLICABLE + + cdef bint load_from_cache(self): key = ( - query, - output_format, - implicit_limit, - inline_typenames, - inline_typeids, - expect_one, + self.query, + self.output_format, + self.implicit_limit, + self.inline_typenames, + self.inline_typeids, + self.expect_one, ) - return self.queries.get(key, None) + rv = self.qc.get(key, None) + if rv is None: + return False + else: + self.cardinality, self.in_dc, self.out_dc, self.capabilities = rv + return True - cdef set( - self, str query, OutputFormat output_format, - int implicit_limit, bint inline_typenames, bint inline_typeids, - bint expect_one, bint has_na_cardinality, - BaseCodec in_type, BaseCodec out_type, int capabilities, - ): + cdef inline store_to_cache(self): + assert self.in_dc is not None + assert self.out_dc is not None key = ( - query, - output_format, - implicit_limit, - inline_typenames, - inline_typeids, - expect_one, + self.query, + self.output_format, + self.implicit_limit, + self.inline_typenames, + self.inline_typeids, + self.expect_one, ) - assert in_type is not None - assert out_type is not None - self.queries[key] = ( - has_na_cardinality, in_type, out_type, capabilities + self.qc[key] = ( + self.cardinality, self.in_dc, self.out_dc, self.capabilities ) @@ -209,63 +239,37 @@ cdef class SansIOProtocol: raise errors.ClientConnectionClosedError( 'the connection has been closed') - cdef WriteBuffer encode_parse_params( - self, - str query, - object output_format, - bint expect_one, - int implicit_limit, - bint inline_typenames, - bint inline_typeids, - uint64_t allow_capabilities, - object state, - ): + cdef WriteBuffer encode_parse_params(self, ExecuteContext ctx): cdef: WriteBuffer buf compilation_flags = enums.CompilationFlag.INJECT_OUTPUT_OBJECT_IDS - if inline_typenames: + if ctx.inline_typenames: compilation_flags |= enums.CompilationFlag.INJECT_OUTPUT_TYPE_NAMES - if inline_typeids: + if ctx.inline_typeids: compilation_flags |= enums.CompilationFlag.INJECT_OUTPUT_TYPE_IDS buf = WriteBuffer.new() - buf.write_int64(allow_capabilities) + buf.write_int64(ctx.allow_capabilities) buf.write_int64(compilation_flags) - buf.write_int64(implicit_limit) - buf.write_byte(output_format) - buf.write_byte(CARDINALITY_ONE if expect_one else CARDINALITY_MANY) - buf.write_len_prefixed_utf8(query) + buf.write_int64(ctx.implicit_limit) + buf.write_byte(ctx.output_format) + buf.write_byte(CARDINALITY_ONE if ctx.expect_one else CARDINALITY_MANY) + buf.write_len_prefixed_utf8(ctx.query) - state_type_id, state_data = self.encode_state(state) + state_type_id, state_data = self.encode_state(ctx.state) buf.write_bytes(state_type_id) buf.write_bytes(state_data) return buf - async def _parse( - self, - query: str, - *, - reg: CodecsRegistry, - output_format: OutputFormat=OutputFormat.BINARY, - expect_one: bint=False, - required_one: bool=False, - implicit_limit: int=0, - inline_typenames: bool=False, - inline_typeids: bool=False, - allow_capabilities: enums.Capability = enums.Capability.ALL, - state: typing.Optional[dict] = None, - ): + async def _parse(self, ctx: ExecuteContext): cdef: WriteBuffer buf, params char mtype - BaseCodec in_dc = None - BaseCodec out_dc = None int16_t type_size bytes in_type_id bytes out_type_id - bytes cardinality if not self.connected: raise RuntimeError('not connected') @@ -273,16 +277,7 @@ cdef class SansIOProtocol: buf = WriteBuffer.new_message(PREPARE_MSG) buf.write_int16(0) # no headers - params = self.encode_parse_params( - query=query, - output_format=output_format, - expect_one=expect_one, - implicit_limit=implicit_limit, - inline_typenames=inline_typenames, - inline_typeids=inline_typeids, - allow_capabilities=allow_capabilities, - state=state, - ) + params = self.encode_parse_params(ctx) buf.write_buffer(params) buf.end_message() @@ -297,17 +292,20 @@ cdef class SansIOProtocol: try: if mtype == STMT_DATA_DESC_MSG: - capabilities, cardinality, in_dc, out_dc = \ - self.parse_describe_type_message(reg) + self.parse_describe_type_message(ctx) elif mtype == STATE_DATA_DESC_MSG: self.parse_describe_state_message() elif mtype == ERROR_RESPONSE_MSG: exc = self.parse_error_message() - exc._query = query + exc._query = ctx.query exc = self._amend_parse_error( - exc, output_format, expect_one, required_one) + exc, + ctx.output_format, + ctx.expect_one, + ctx.required_one, + ) elif mtype == READY_FOR_COMMAND_MSG: self.parse_sync_message() @@ -321,62 +319,32 @@ cdef class SansIOProtocol: if exc is not None: raise exc - if required_one and cardinality == CARDINALITY_NOT_APPLICABLE: - assert output_format != OutputFormat.NONE - methname = _QUERY_SINGLE_METHOD[required_one][output_format] + if ctx.required_one and ctx.has_na_cardinality(): + assert ctx.output_format != OutputFormat.NONE + methname = _QUERY_SINGLE_METHOD[ctx.required_one][ctx.output_format] raise errors.InterfaceError( f'query cannot be executed with {methname}() as it ' f'does not return any data') - return cardinality, in_dc, out_dc, capabilities - - async def _execute( - self, - *, - query: str, - args, - kwargs, - reg: CodecsRegistry, - qc: QueryCodecsCache, - output_format: object, - expect_one: bint, - required_one: bint, - implicit_limit: int, - inline_typenames: bint, - inline_typeids: bint, - allow_capabilities: enums.Capability = enums.Capability.ALL, - in_dc: BaseCodec, - out_dc: BaseCodec, - state: typing.Optional[dict] = None, - ): + async def _execute(self, ctx: ExecuteContext): cdef: WriteBuffer packet WriteBuffer buf WriteBuffer params char mtype object result - bytes new_cardinality = None - - params = self.encode_parse_params( - query=query, - output_format=output_format, - expect_one=expect_one, - implicit_limit=implicit_limit, - inline_typenames=inline_typenames, - inline_typeids=inline_typeids, - allow_capabilities=allow_capabilities, - state=state, - ) + + params = self.encode_parse_params(ctx) buf = WriteBuffer.new_message(EXECUTE_MSG) buf.write_int16(0) # no headers buf.write_buffer(params) - buf.write_bytes(in_dc.get_tid()) - buf.write_bytes(out_dc.get_tid()) + buf.write_bytes(ctx.in_dc.get_tid()) + buf.write_bytes(ctx.out_dc.get_tid()) - self.encode_args(in_dc, buf, args, kwargs) + self.encode_args(ctx.in_dc, buf, ctx.args, ctx.kwargs) buf.end_message() @@ -395,18 +363,8 @@ cdef class SansIOProtocol: try: if mtype == STMT_DATA_DESC_MSG: # our in/out type spec is out-dated - capabilities, new_cardinality, in_dc, out_dc = \ - self.parse_describe_type_message(reg) - - qc.set( - query, - output_format, - implicit_limit, - inline_typenames, - inline_typeids, - expect_one, - new_cardinality == CARDINALITY_NOT_APPLICABLE, - in_dc, out_dc, capabilities) + self.parse_describe_type_message(ctx) + ctx.store_to_cache() elif mtype == STATE_DATA_DESC_MSG: self.parse_describe_state_message() @@ -414,7 +372,7 @@ cdef class SansIOProtocol: elif mtype == DATA_MSG: if exc is None: try: - self.parse_data_messages(out_dc, result) + self.parse_data_messages(ctx.out_dc, result) except Exception as ex: # An error during data decoding. We need to # handle this as gracefully as possible: @@ -436,19 +394,25 @@ cdef class SansIOProtocol: elif mtype == ERROR_RESPONSE_MSG: exc = self.parse_error_message() - exc._query = query + exc._query = ctx.query if exc.get_code() == parameter_type_mismatch_code: - if not isinstance(in_dc, NullCodec): + if not isinstance(ctx.in_dc, NullCodec): buf = WriteBuffer.new() try: - self.encode_args(in_dc, buf, args, kwargs) + self.encode_args( + ctx.in_dc, buf, ctx.args, ctx.kwargs + ) except errors.QueryArgumentError as ex: exc = ex finally: buf = None else: exc = self._amend_parse_error( - exc, output_format, expect_one, required_one) + exc, + ctx.output_format, + ctx.expect_one, + ctx.required_one, + ) elif mtype == READY_FOR_COMMAND_MSG: self.parse_sync_message() @@ -481,153 +445,51 @@ cdef class SansIOProtocol: else: return NULL_CODEC_ID, EMPTY_NULL_DATA - async def execute( - self, - *, - query: str, - args, - kwargs, - reg: CodecsRegistry, - qc: QueryCodecsCache, - output_format: object, - expect_one: bint = False, - required_one: bool = False, - implicit_limit: int = 0, - inline_typenames: bool = False, - inline_typeids: bool = False, - allow_capabilities: enums.Capability = enums.Capability.ALL, - state: typing.Optional[dict] = None, - ): - cdef: - BaseCodec in_dc - BaseCodec out_dc - + async def execute(self, ctx: ExecuteContext): self.ensure_connected() self.reset_status() - codecs = qc.get( - query, - output_format, - implicit_limit, - inline_typenames, - inline_typeids, - expect_one) - - if codecs is not None: - in_dc = codecs[1] - out_dc = codecs[2] - elif not args and not kwargs and not required_one: + if ctx.load_from_cache(): + pass + elif not ctx.args and not ctx.kwargs and not ctx.required_one: # We don't have knowledge about the in/out desc of the command, but # the caller didn't provide any arguments, so let's try using NULL # for both in (assumed) and out (the server will correct it) desc # without an additional Parse, unless required_one is set because # it'll be too late to find out the cardinality is wrong when the # command is already executed. - in_dc = out_dc = NULL_CODEC + ctx.in_dc = ctx.out_dc = NULL_CODEC else: - parsed = await self._parse( - query, - reg=reg, - output_format=output_format, - expect_one=expect_one, - required_one=required_one, - implicit_limit=implicit_limit, - inline_typenames=inline_typenames, - inline_typeids=inline_typeids, - allow_capabilities=allow_capabilities, - state=state, - ) - - has_na_cardinality = parsed[0] == CARDINALITY_NOT_APPLICABLE - in_dc = parsed[1] - out_dc = parsed[2] - capabilities = parsed[3] - - qc.set( - query, - output_format, - implicit_limit, - inline_typenames, - inline_typeids, - expect_one, - has_na_cardinality, - in_dc, - out_dc, - capabilities, - ) - - return await self._execute( - query=query, - args=args, - kwargs=kwargs, - reg=reg, - qc=qc, - output_format=output_format, - expect_one=expect_one, - required_one=required_one, - implicit_limit=implicit_limit, - inline_typenames=inline_typenames, - inline_typeids=inline_typeids, - allow_capabilities=allow_capabilities, - in_dc=in_dc, - out_dc=out_dc, - state=state, - ) + await self._parse(ctx) + ctx.store_to_cache() - async def query( - self, - *, - query: str, - args, - kwargs, - reg: CodecsRegistry, - qc: QueryCodecsCache, - output_format: object, - expect_one: bint = False, - required_one: bool = False, - implicit_limit: int = 0, - inline_typenames: bool = False, - inline_typeids: bool = False, - allow_capabilities: enums.Capability = enums.Capability.ALL, - state: typing.Optional[dict] = None, - ): - ret = await self.execute( - query=query, - args=args, - kwargs=kwargs, - reg=reg, - qc=qc, - output_format=output_format, - expect_one=expect_one, - required_one=required_one, - implicit_limit=implicit_limit, - inline_typenames=inline_typenames, - inline_typeids=inline_typeids, - allow_capabilities=allow_capabilities, - state=state, - ) + return await self._execute(ctx) - if expect_one: - if ret or not required_one: + async def query(self, ctx: ExecuteContext): + ret = await self.execute(ctx) + if ctx.expect_one: + if ret or not ctx.required_one: if ret: return ret[0] else: - if output_format == OutputFormat.JSON: + if ctx.output_format == OutputFormat.JSON: return 'null' else: return None else: - methname = _QUERY_SINGLE_METHOD[required_one][output_format] + methname = ( + _QUERY_SINGLE_METHOD[ctx.required_one][ctx.output_format] + ) raise errors.NoDataError( f'query executed via {methname}() returned no data') else: if ret: - if output_format == OutputFormat.JSON: + if ctx.output_format == OutputFormat.JSON: return ret[0] else: return ret else: - if output_format == OutputFormat.JSON: + if ctx.output_format == OutputFormat.JSON: return '[]' else: return ret @@ -1121,22 +983,17 @@ cdef class SansIOProtocol: (in_dc).encode_args(buf, kwargs) - cdef parse_describe_type_message(self, CodecsRegistry reg): + cdef parse_describe_type_message(self, ExecuteContext ctx): assert self.buffer.get_message_type() == COMMAND_DATA_DESC_MSG - cdef: - bytes cardinality - try: self.ignore_headers() - capabilities = self.buffer.read_int64() - cardinality = self.buffer.read_byte() - in_dc, out_dc = self.parse_type_data(reg) + ctx.capabilities = self.buffer.read_int64() + ctx.cardinality = self.buffer.read_byte() + ctx.in_dc, ctx.out_dc = self.parse_type_data(ctx.reg) finally: self.buffer.finish_message() - return capabilities, cardinality, in_dc, out_dc - cdef parse_describe_state_message(self): assert self.buffer.get_message_type() == STATE_DATA_DESC_MSG try: diff --git a/edgedb/protocol/protocol_v0.pyx b/edgedb/protocol/protocol_v0.pyx index 399f803c..2a4cb80b 100644 --- a/edgedb/protocol/protocol_v0.pyx +++ b/edgedb/protocol/protocol_v0.pyx @@ -225,24 +225,7 @@ cdef class SansIOProtocolBackwardsCompatible(SansIOProtocol): return result - async def _legacy_optimistic_execute( - self, - *, - query: str, - args, - kwargs, - reg: CodecsRegistry, - qc: QueryCodecsCache, - output_format: object, - expect_one: bint, - required_one: bint, - implicit_limit: int, - inline_typenames: bint, - inline_typeids: bint, - allow_capabilities: typing.Optional[int] = None, - in_dc: BaseCodec, - out_dc: BaseCodec, - ): + async def _legacy_optimistic_execute(self, ctx: ExecuteContext): cdef: WriteBuffer packet WriteBuffer buf @@ -251,6 +234,20 @@ cdef class SansIOProtocolBackwardsCompatible(SansIOProtocol): object result bytes new_cardinality = None + str query = ctx.query + object args = ctx.args + object kwargs = ctx.kwargs + CodecsRegistry reg = ctx.reg + OutputFormat output_format = ctx.output_format + bint expect_one = ctx.expect_one + bint required_one = ctx.required_one + int implicit_limit = ctx.implicit_limit + bint inline_typenames = ctx.inline_typenames + bint inline_typeids = ctx.inline_typeids + uint64_t allow_capabilities = ctx.allow_capabilities + BaseCodec in_dc = ctx.in_dc + BaseCodec out_dc = ctx.out_dc + buf = WriteBuffer.new_message(EXECUTE_MSG) self.legacy_write_execute_headers( buf, implicit_limit, inline_typenames, inline_typeids, @@ -287,15 +284,12 @@ cdef class SansIOProtocolBackwardsCompatible(SansIOProtocol): if capabilities is not None: capabilities = int.from_bytes(capabilities, 'big') - qc.set( - query, - output_format, - implicit_limit, - inline_typenames, - inline_typeids, - expect_one, - new_cardinality == CARDINALITY_NOT_APPLICABLE, - in_dc, out_dc, capabilities) + ctx.cardinality = new_cardinality + ctx.in_dc = in_dc + ctx.out_dc = out_dc + ctx.capabilities = capabilities + ctx.store_to_cache() + re_exec = True elif mtype == DATA_MSG: @@ -351,37 +345,27 @@ cdef class SansIOProtocolBackwardsCompatible(SansIOProtocol): else: return result - async def legacy_execute_anonymous( - self, - *, - query: str, - args, - kwargs, - reg: CodecsRegistry, - qc: QueryCodecsCache, - output_format: object, - expect_one: bint = False, - required_one: bool = False, - implicit_limit: int = 0, - inline_typenames: bool = False, - inline_typeids: bool = False, - allow_capabilities: enums.Capability = enums.Capability.ALL, - ): + async def legacy_execute_anonymous(self, ctx: ExecuteContext): cdef: BaseCodec in_dc BaseCodec out_dc + str query = ctx.query + object args = ctx.args + object kwargs = ctx.kwargs + CodecsRegistry reg = ctx.reg + OutputFormat output_format = ctx.output_format + bint expect_one = ctx.expect_one + bint required_one = ctx.required_one + int implicit_limit = ctx.implicit_limit + bint inline_typenames = ctx.inline_typenames + bint inline_typeids = ctx.inline_typeids + uint64_t allow_capabilities = ctx.allow_capabilities + self.ensure_connected() self.reset_status() - codecs = qc.get( - query, - output_format, - implicit_limit, - inline_typenames, - inline_typeids, - expect_one) - if codecs is None: + if not ctx.load_from_cache(): codecs = await self._legacy_parse( query, reg=reg, @@ -405,48 +389,22 @@ cdef class SansIOProtocolBackwardsCompatible(SansIOProtocol): if capabilities is not None: capabilities = int.from_bytes(capabilities, 'big') - qc.set( - query, - output_format, - implicit_limit, - inline_typenames, - inline_typeids, - expect_one, - cardinality == CARDINALITY_NOT_APPLICABLE, - in_dc, - out_dc, - capabilities, - ) + ctx.cardinality = cardinality + ctx.in_dc = in_dc + ctx.out_dc = out_dc + ctx.capabilities = capabilities + ctx.store_to_cache() ret = await self._legacy_execute(in_dc, out_dc, args, kwargs) else: - has_na_cardinality = codecs[0] - in_dc = codecs[1] - out_dc = codecs[2] - - if required_one and has_na_cardinality: + if required_one and ctx.has_na_cardinality(): methname = _QUERY_SINGLE_METHOD[required_one][output_format] raise errors.InterfaceError( f'query cannot be executed with {methname}() as it ' f'does not return any data') - ret = await self._legacy_optimistic_execute( - query=query, - args=args, - kwargs=kwargs, - reg=reg, - qc=qc, - output_format=output_format, - expect_one=expect_one, - required_one=required_one, - implicit_limit=implicit_limit, - inline_typenames=inline_typenames, - inline_typeids=inline_typeids, - allow_capabilities=allow_capabilities, - in_dc=in_dc, - out_dc=out_dc, - ) + ret = await self._legacy_optimistic_execute(ctx) if expect_one: if ret or not required_one: diff --git a/tests/test_sync_retry.py b/tests/test_sync_retry.py index 831f0964..ae32c633 100644 --- a/tests/test_sync_retry.py +++ b/tests/test_sync_retry.py @@ -17,7 +17,9 @@ # +import asyncio import threading +import queue import unittest.mock from concurrent import futures @@ -254,3 +256,76 @@ def test_sync_transaction_interface_errors(self): with tx: with tx: pass + + def test_sync_retry_parse(self): + loop = asyncio.new_event_loop() + q = queue.Queue() + + async def init(): + return asyncio.Event(), asyncio.Event() + + reconnect, terminate = loop.run_until_complete(init()) + + async def proxy(r, w): + try: + while True: + buf = await r.read(65536) + if not buf: + w.close() + break + w.write(buf) + except asyncio.CancelledError: + pass + + async def cb(ri, wi): + try: + args = self.get_connect_args() + ro, wo = await asyncio.open_connection( + args["host"], args["port"] + ) + try: + fs = [ + asyncio.create_task(proxy(ri, wo)), + asyncio.create_task(proxy(ro, wi)), + asyncio.create_task(terminate.wait()), + ] + if not reconnect.is_set(): + fs.append(asyncio.create_task(reconnect.wait())) + _, pending = await asyncio.wait( + fs, return_when=asyncio.FIRST_COMPLETED + ) + for f in pending: + f.cancel() + finally: + wo.close() + finally: + wi.close() + + async def proxy_server(): + srv = await asyncio.start_server(cb, host="127.0.0.1", port=0) + try: + q.put(srv.sockets[0].getsockname()[1]) + await terminate.wait() + finally: + srv.close() + await srv.wait_closed() + + with futures.ThreadPoolExecutor(1) as pool: + pool.submit(loop.run_until_complete, proxy_server()) + try: + client = self.make_test_client( + host="127.0.0.1", + port=q.get(), + database=self.get_database_name(), + ) + + # Fill the connection pool with a healthy connection + self.assertEqual(client.query_single("SELECT 42"), 42) + + # Cut the connection to simulate an Internet interruption + loop.call_soon_threadsafe(reconnect.set) + + # Run a new query that was never compiled, retry should work + self.assertEqual(client.query_single("SELECT 1*2+3-4"), 1) + finally: + loop.call_soon_threadsafe(terminate.set) diff --git a/tests/test_sync_tx.py b/tests/test_sync_tx.py index 3ed2fc55..497af782 100644 --- a/tests/test_sync_tx.py +++ b/tests/test_sync_tx.py @@ -102,7 +102,7 @@ def test_sync_transaction_commit_failure(self): def test_sync_transaction_exclusive(self): for tx in self.client.transaction(): with tx: - query = "select sys::_sleep(0.01)" + query = "select sys::_sleep(0.5)" with ThreadPoolExecutor(max_workers=2) as executor: f1 = executor.submit(tx.execute, query) f2 = executor.submit(tx.execute, query)