From 3285fcbb5090ed2ee5573edbc88a7a6b1766b5f2 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Wed, 7 Feb 2024 14:35:30 -0800 Subject: [PATCH] Make DROP DATABASE properly wait for connections to drop (#6782) DROP DATABASE attempts to wait in a retry loop for all incoming connections to be dropped, but it doesn't quite work, and so we have hacked around it with another retry loop in the test suite. One problem was that we were looking for the *edgedb* database name in pg_stat_activity, not the *postgres* database name. The other source of issues was that, as I suspected in #4567, the connection pool was establishing new connections to the database we were trying to close. This came in two forms: * Transfers scheduled *after* we closed all the connections (since now the db was under quota!): this we fix by setting a flag to prevent. * Transfers that are *already* in flight. This we fix by waiting for all of the pending connections to complete, and then closing them too. Once those issues were fixed, there were some failures in the test suite where DROP failed because of remaining *edgedb* connections. The problem there turned out to be in the test suite, where we were cancelling some connections which resulted in non-synchronous closes. Fixes #4567. --- edb/common/asyncutil.py | 64 +++++++++++++++++++++++++++++++++++++ edb/server/connpool/pool.py | 48 ++++++++++++++++++++-------- edb/server/tenant.py | 4 +-- edb/testbase/server.py | 13 +------- tests/test_server_proto.py | 16 +++++++--- 5 files changed, 113 insertions(+), 32 deletions(-) create mode 100644 edb/common/asyncutil.py diff --git a/edb/common/asyncutil.py b/edb/common/asyncutil.py new file mode 100644 index 00000000000..7461ec55a50 --- /dev/null +++ b/edb/common/asyncutil.py @@ -0,0 +1,64 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2018-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from __future__ import annotations +from typing import * + +import asyncio + + +_T = TypeVar('_T') + + +async def deferred_shield(arg: Awaitable[_T]) -> _T: + '''Wait for a future, deferring cancellation until it is complete. + + If you do + await deferred_shield(something()) + + it is approximately equivalent to + await something() + + except that if the coroutine containing it is cancelled, + something() is protected from cancellation, and *additionally* + CancelledError is not raised in the caller until something() + completes. + + This can be useful if something() contains something that + shouldn't be interrupted but also can't be safely left running + asynchronously. + ''' + task = asyncio.ensure_future(arg) + + ex = None + while not task.done(): + try: + await asyncio.shield(task) + except asyncio.CancelledError as cex: + if ex is not None: + cex.__context__ = ex + ex = cex + except Exception: + if ex: + raise ex from None + raise + + if ex: + raise ex + return task.result() diff --git a/edb/server/connpool/pool.py b/edb/server/connpool/pool.py index 6500679935a..28b668bfeac 100644 --- a/edb/server/connpool/pool.py +++ b/edb/server/connpool/pool.py @@ -133,6 +133,7 @@ class Block(typing.Generic[C]): querytime_avg: rolavg.RollingAverage nwaiters_avg: rolavg.RollingAverage + suppressed: bool _cached_calibrated_demand: float @@ -161,6 +162,7 @@ def __init__( self.querytime_avg = rolavg.RollingAverage(history_size=20) self.nwaiters_avg = rolavg.RollingAverage(history_size=3) + self.suppressed = False self._is_log_batching = False self._last_log_timestamp = 0 @@ -226,24 +228,17 @@ def try_steal( return self.conn_stack.popleft() - async def acquire(self) -> C: - # There can be a race between a waiter scheduled for to wake up - # and a connection being stolen (due to quota being enforced, - # for example). In which case the waiter might get finally - # woken up with an empty queue -- hence we use a `while` loop here. + async def try_acquire(self, *, attempts: int = 1) -> typing.Optional[C]: self.conn_waiters_num += 1 try: - attempts = 0 - # Skip the waiters' queue if we can grab a connection from the # stack immediately - this is not completely fair, but it's # extremely hard to always take the shortcut and starve the queue # without blocking the main loop, so we are fine here. (This is # also how asyncio.Queue is implemented.) - while not self.conn_stack: + if not self.conn_stack: waiter = self.loop.create_future() - attempts += 1 if attempts > 1: # If the waiter was woken up only to discover that # it needs to wait again, we don't want it to lose @@ -271,11 +266,26 @@ async def acquire(self) -> C: self._wakeup_next_waiter() raise + # There can be a race between a waiter scheduled for to wake up + # and a connection being stolen (due to quota being enforced, + # for example). In which case the waiter might get finally + # woken up with an empty queue -- hence the 'try'. + # acquire will put a while loop around this + # Yield the most recently used connection from the top of the stack - return self.conn_stack.pop() + if self.conn_stack: + return self.conn_stack.pop() + else: + return None finally: self.conn_waiters_num -= 1 + async def acquire(self) -> C: + attempts = 1 + while (c := await self.try_acquire(attempts=attempts)) is None: + attempts += 1 + return c + def release(self, conn: C) -> None: # Put the connection (back) to the top of the stack, self.conn_stack.append(conn) @@ -749,7 +759,7 @@ def _tick(self) -> None: total_nwaiters += nwaiters block.nwaiters_avg.add(nwaiters) nwaiters_avg = block.nwaiters_avg.avg() - if nwaiters_avg: + if nwaiters_avg and not block.suppressed: # GOTCHA: this is a counter of blocks that need at least 1 # connection. If this number is greater than _max_capacity, # some block will be starving with zero connection. @@ -1023,7 +1033,7 @@ def _find_most_starving_block( block_size = block.count_conns() block_demand = block.count_waiters() - if block_size or not block_demand: + if block_size or not block_demand or block.suppressed: continue if block_demand > max_need: @@ -1039,7 +1049,7 @@ def _find_most_starving_block( for block in self._blocks.values(): block_size = block.count_conns() block_quota = block.quota - if block_quota > block_size: + if block_quota > block_size and not block.suppressed: need = block_quota - block_size if need > max_need: max_need = need @@ -1052,6 +1062,7 @@ def _find_most_starving_block( async def _acquire(self, dbname: str) -> C: block = self._get_block(dbname) + block.suppressed = False room_for_new_conns = self._cur_capacity < self._max_capacity block_nconns = block.count_conns() @@ -1191,10 +1202,21 @@ async def prune_inactive_connections(self, dbname: str) -> None: except KeyError: return None + # Mark the block as suppressed, so that nothing will be + # transferred to it. It will be unsuppressed if anything + # actually tries to connect. + # TODO: Is it possible to safely drop the block? + block.suppressed = True + conns = [] while (conn := block.try_steal()) is not None: conns.append(conn) + while not block.count_waiters() and block.pending_conns: + # try_acquire, because it can get stolen + if c := await block.try_acquire(): + conns.append(c) + if conns: await asyncio.gather( *(self._discard_conn(block, conn) for conn in conns), diff --git a/edb/server/tenant.py b/edb/server/tenant.py index 3a1a2bc3757..659924de0a1 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -749,13 +749,13 @@ async def ensure_database_not_connected(self, dbname: str) -> None: ) rloop = retryloop.RetryLoop( - backoff=retryloop.exp_backoff(), timeout=10.0, ignore=errors.ExecutionError, ) async for iteration in rloop: async with iteration: + # Verify we are disconnected await self._pg_ensure_database_not_connected(dbname) async def _pg_ensure_database_not_connected(self, dbname: str) -> None: @@ -769,7 +769,7 @@ async def _pg_ensure_database_not_connected(self, dbname: str) -> None: WHERE datname = $1 """, - args=[dbname.encode("utf-8")], + args=[self.get_pg_dbname(dbname).encode("utf-8")], ) if conns: diff --git a/edb/testbase/server.py b/edb/testbase/server.py index 44ddf108f46..98534e38e15 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -637,18 +637,7 @@ def _extract_background_errors(metrics: str) -> str | None: async def drop_db(conn, dbname): - # The connection might not *actually* be closed on the db - # side yet. This is a bug (#4567), but hack around it - # with a retry loop. Without this, tests would flake - # a lot. - async for tr in TestCase.try_until_succeeds( - ignore=(edgedb.ExecutionError, edgedb.ClientConnectionError), - timeout=30 - ): - async with tr: - await conn.execute( - f'DROP DATABASE {dbname};' - ) + await conn.execute(f'DROP DATABASE {dbname}') class ClusterTestCase(BaseHTTPTestCase): diff --git a/tests/test_server_proto.py b/tests/test_server_proto.py index 2ee7a53bc85..d773522c7f6 100644 --- a/tests/test_server_proto.py +++ b/tests/test_server_proto.py @@ -27,6 +27,7 @@ from edb.common import devmode from edb.common import taskgroup as tg +from edb.common import asyncutil from edb.testbase import server as tb from edb.server.compiler import enums from edb.tools import test @@ -2282,7 +2283,6 @@ async def test_server_adjacent_extension_propagation(self): async with tb.start_edgedb_server(**server_args) as sd: - print("SERVER", sd) await self.con.execute("CREATE EXTENSION notebook;") # First, ensure that the local server is aware of the new ext. @@ -3207,7 +3207,10 @@ async def test_server_proto_concurrent_ddl(self): try: async with tg.TaskGroup() as g: for i, con in enumerate(cons): - g.create_task(con.execute(f''' + # deferred_shield ensures that none of the + # operations get cancelled, which allows us to + # aclose them all cleanly. + g.create_task(asyncutil.deferred_shield(con.execute(f''' CREATE TYPE {typename_prefix}{i} {{ CREATE REQUIRED PROPERTY prop1 -> std::int64; }}; @@ -3215,7 +3218,7 @@ async def test_server_proto_concurrent_ddl(self): INSERT {typename_prefix}{i} {{ prop1 := {i} }}; - ''')) + '''))) except tg.TaskGroupError as e: self.assertIn( edgedb.TransactionSerializationError, @@ -3250,9 +3253,12 @@ async def test_server_proto_concurrent_global_ddl(self): try: async with tg.TaskGroup() as g: for i, con in enumerate(cons): - g.create_task(con.execute(f''' + # deferred_shield ensures that none of the + # operations get cancelled, which allows us to + # aclose them all cleanly. + g.create_task(asyncutil.deferred_shield(con.execute(f''' CREATE SUPERUSER ROLE concurrent_{i} - ''')) + '''))) except tg.TaskGroupError as e: self.assertIn( edgedb.TransactionSerializationError,