Skip to content

Commit

Permalink
Make DROP DATABASE properly wait for connections to drop (#6782)
Browse files Browse the repository at this point in the history
DROP DATABASE attempts to wait in a retry loop for all incoming
connections to be dropped, but it doesn't quite work, and so we have
hacked around it with another retry loop in the test suite.

One problem was that we were looking for the *edgedb* database name in
pg_stat_activity, not the *postgres* database name.

The other source of issues was that, as I suspected in #4567, the
connection pool was establishing new connections to the database we
were trying to close. This came in two forms:
 * Transfers scheduled *after* we closed all the connections (since
   now the db was under quota!): this we fix by setting a flag that
   prevents any further transfers to the block.
 * Transfers that are *already* in flight. This we fix by waiting for
   all of the pending connections to complete, and then closing them
   too.

Once those issues were fixed, there were some failures in the test
suite where DROP failed because of remaining *edgedb* connections.
The problem there turned out to be in the test suite, where we were
cancelling some connections which resulted in non-synchronous closes.

Fixes #4567.
  • Loading branch information
msullivan authored Feb 7, 2024
1 parent 77a66d1 commit 3285fcb
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 32 deletions.
64 changes: 64 additions & 0 deletions edb/common/asyncutil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2018-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations
from typing import *

import asyncio


_T = TypeVar('_T')


async def deferred_shield(arg: Awaitable[_T]) -> _T:
    '''Wait for a future, deferring cancellation until it is complete.

    If you do
        await deferred_shield(something())
    it is approximately equivalent to
        await something()
    except that if the coroutine containing it is cancelled,
    something() is protected from cancellation, and *additionally*
    CancelledError is not raised in the caller until something()
    completes.

    This can be useful if something() contains something that
    shouldn't be interrupted but also can't be safely left running
    asynchronously.
    '''
    task = asyncio.ensure_future(arg)

    # Most recent cancellation request received while waiting; it is
    # delivered only after `task` has completed.
    ex: Optional[asyncio.CancelledError] = None
    # asyncio.shield() protects `task` itself, but the shield wrapper
    # is still cancellable: cancelling us raises CancelledError here
    # while `task` keeps running.  So keep re-shielding until the task
    # is actually done, remembering the cancellation for later.
    while not task.done():
        try:
            await asyncio.shield(task)
        except asyncio.CancelledError as cex:
            # Chain repeated cancellations together so none is lost.
            if ex is not None:
                cex.__context__ = ex
            ex = cex
        except Exception:
            # The task itself failed.  If a cancellation is pending,
            # deliver that instead (suppressing the implicit context);
            # otherwise let the task's own exception propagate.
            if ex:
                raise ex from None
            raise

    # Task finished: now deliver the deferred cancellation, or the
    # task's result.
    if ex:
        raise ex
    return task.result()
48 changes: 35 additions & 13 deletions edb/server/connpool/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class Block(typing.Generic[C]):

querytime_avg: rolavg.RollingAverage
nwaiters_avg: rolavg.RollingAverage
suppressed: bool

_cached_calibrated_demand: float

Expand Down Expand Up @@ -161,6 +162,7 @@ def __init__(

self.querytime_avg = rolavg.RollingAverage(history_size=20)
self.nwaiters_avg = rolavg.RollingAverage(history_size=3)
self.suppressed = False

self._is_log_batching = False
self._last_log_timestamp = 0
Expand Down Expand Up @@ -226,24 +228,17 @@ def try_steal(

return self.conn_stack.popleft()

async def acquire(self) -> C:
# There can be a race between a waiter scheduled to wake up
# and a connection being stolen (due to quota being enforced,
# for example). In which case the waiter might get finally
# woken up with an empty queue -- hence we use a `while` loop here.
async def try_acquire(self, *, attempts: int = 1) -> typing.Optional[C]:
self.conn_waiters_num += 1
try:
attempts = 0

# Skip the waiters' queue if we can grab a connection from the
# stack immediately - this is not completely fair, but it's
# extremely hard to always take the shortcut and starve the queue
# without blocking the main loop, so we are fine here. (This is
# also how asyncio.Queue is implemented.)
while not self.conn_stack:
if not self.conn_stack:
waiter = self.loop.create_future()

attempts += 1
if attempts > 1:
# If the waiter was woken up only to discover that
# it needs to wait again, we don't want it to lose
Expand Down Expand Up @@ -271,11 +266,26 @@ async def acquire(self) -> C:
self._wakeup_next_waiter()
raise

# There can be a race between a waiter scheduled to wake up
# and a connection being stolen (due to quota being enforced,
# for example). In which case the waiter might get finally
# woken up with an empty queue -- hence the 'try'.
# acquire will put a while loop around this

# Yield the most recently used connection from the top of the stack
return self.conn_stack.pop()
if self.conn_stack:
return self.conn_stack.pop()
else:
return None
finally:
self.conn_waiters_num -= 1

async def acquire(self) -> C:
    """Acquire a connection, waiting until one becomes available.

    A woken-up waiter can still find the stack empty (its connection
    may have been stolen while it was being scheduled), so keep
    calling try_acquire() until it yields a connection, passing the
    attempt count along so repeat waiters are not sent to the back of
    the queue.
    """
    attempt = 1
    while True:
        conn = await self.try_acquire(attempts=attempt)
        if conn is not None:
            return conn
        attempt += 1

def release(self, conn: C) -> None:
# Put the connection (back) to the top of the stack,
self.conn_stack.append(conn)
Expand Down Expand Up @@ -749,7 +759,7 @@ def _tick(self) -> None:
total_nwaiters += nwaiters
block.nwaiters_avg.add(nwaiters)
nwaiters_avg = block.nwaiters_avg.avg()
if nwaiters_avg:
if nwaiters_avg and not block.suppressed:
# GOTCHA: this is a counter of blocks that need at least 1
# connection. If this number is greater than _max_capacity,
# some block will be starving with zero connection.
Expand Down Expand Up @@ -1023,7 +1033,7 @@ def _find_most_starving_block(
block_size = block.count_conns()
block_demand = block.count_waiters()

if block_size or not block_demand:
if block_size or not block_demand or block.suppressed:
continue

if block_demand > max_need:
Expand All @@ -1039,7 +1049,7 @@ def _find_most_starving_block(
for block in self._blocks.values():
block_size = block.count_conns()
block_quota = block.quota
if block_quota > block_size:
if block_quota > block_size and not block.suppressed:
need = block_quota - block_size
if need > max_need:
max_need = need
Expand All @@ -1052,6 +1062,7 @@ def _find_most_starving_block(

async def _acquire(self, dbname: str) -> C:
block = self._get_block(dbname)
block.suppressed = False

room_for_new_conns = self._cur_capacity < self._max_capacity
block_nconns = block.count_conns()
Expand Down Expand Up @@ -1191,10 +1202,21 @@ async def prune_inactive_connections(self, dbname: str) -> None:
except KeyError:
return None

# Mark the block as suppressed, so that nothing will be
# transferred to it. It will be unsuppressed if anything
# actually tries to connect.
# TODO: Is it possible to safely drop the block?
block.suppressed = True

conns = []
while (conn := block.try_steal()) is not None:
conns.append(conn)

while not block.count_waiters() and block.pending_conns:
# try_acquire, because it can get stolen
if c := await block.try_acquire():
conns.append(c)

if conns:
await asyncio.gather(
*(self._discard_conn(block, conn) for conn in conns),
Expand Down
4 changes: 2 additions & 2 deletions edb/server/tenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,13 +749,13 @@ async def ensure_database_not_connected(self, dbname: str) -> None:
)

rloop = retryloop.RetryLoop(
backoff=retryloop.exp_backoff(),
timeout=10.0,
ignore=errors.ExecutionError,
)

async for iteration in rloop:
async with iteration:
# Verify we are disconnected
await self._pg_ensure_database_not_connected(dbname)

async def _pg_ensure_database_not_connected(self, dbname: str) -> None:
Expand All @@ -769,7 +769,7 @@ async def _pg_ensure_database_not_connected(self, dbname: str) -> None:
WHERE
datname = $1
""",
args=[dbname.encode("utf-8")],
args=[self.get_pg_dbname(dbname).encode("utf-8")],
)

if conns:
Expand Down
13 changes: 1 addition & 12 deletions edb/testbase/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,18 +637,7 @@ def _extract_background_errors(metrics: str) -> str | None:


async def drop_db(conn, dbname):
    """Drop the database *dbname* using the given connection.

    DROP DATABASE now properly waits server-side for lingering
    connections to be disconnected, so the old client-side retry loop
    (which papered over #4567 and caused test flakiness) is no longer
    needed here.
    """
    await conn.execute(f'DROP DATABASE {dbname}')


class ClusterTestCase(BaseHTTPTestCase):
Expand Down
16 changes: 11 additions & 5 deletions tests/test_server_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from edb.common import devmode
from edb.common import taskgroup as tg
from edb.common import asyncutil
from edb.testbase import server as tb
from edb.server.compiler import enums
from edb.tools import test
Expand Down Expand Up @@ -2282,7 +2283,6 @@ async def test_server_adjacent_extension_propagation(self):

async with tb.start_edgedb_server(**server_args) as sd:

print("SERVER", sd)
await self.con.execute("CREATE EXTENSION notebook;")

# First, ensure that the local server is aware of the new ext.
Expand Down Expand Up @@ -3207,15 +3207,18 @@ async def test_server_proto_concurrent_ddl(self):
try:
async with tg.TaskGroup() as g:
for i, con in enumerate(cons):
g.create_task(con.execute(f'''
# deferred_shield ensures that none of the
# operations get cancelled, which allows us to
# aclose them all cleanly.
g.create_task(asyncutil.deferred_shield(con.execute(f'''
CREATE TYPE {typename_prefix}{i} {{
CREATE REQUIRED PROPERTY prop1 -> std::int64;
}};
INSERT {typename_prefix}{i} {{
prop1 := {i}
}};
'''))
''')))
except tg.TaskGroupError as e:
self.assertIn(
edgedb.TransactionSerializationError,
Expand Down Expand Up @@ -3250,9 +3253,12 @@ async def test_server_proto_concurrent_global_ddl(self):
try:
async with tg.TaskGroup() as g:
for i, con in enumerate(cons):
g.create_task(con.execute(f'''
# deferred_shield ensures that none of the
# operations get cancelled, which allows us to
# aclose them all cleanly.
g.create_task(asyncutil.deferred_shield(con.execute(f'''
CREATE SUPERUSER ROLE concurrent_{i}
'''))
''')))
except tg.TaskGroupError as e:
self.assertIn(
edgedb.TransactionSerializationError,
Expand Down

0 comments on commit 3285fcb

Please sign in to comment.