Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test infrastructure improvements #7628

Merged
merged 2 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions edb/testbase/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2440,7 +2440,7 @@ async def __aenter__(self):
self.proc: asyncio.Process = await asyncio.create_subprocess_exec(
*cmd,
env=env,
stdout=subprocess.PIPE if not self.debug else None,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
pass_fds=(status_w.fileno(),),
)
Expand All @@ -2451,6 +2451,29 @@ async def __aenter__(self):
timeout=240,
),
)

output = b''

async def read_stdout():
nonlocal output
# Tee the log temporarily to a tempfile that exists as long as the
# test is running. This helps debug hanging tests.
with tempfile.NamedTemporaryFile(
mode='w+t',
prefix='edgedb-test-log-') as temp_file:
if self.debug:
print(f"Logging to {temp_file.name}")
while True:
line = await self.proc.stdout.readline()
if not line:
break
output += line
temp_file.write(line.decode(errors='ignore'))
if self.debug:
print(line.decode(errors='ignore'), end='')

stdout_task = asyncio.create_task(read_stdout())

try:
_, pending = await asyncio.wait(
[
Expand All @@ -2476,8 +2499,8 @@ async def __aenter__(self):
await asyncio.wait(pending, timeout=10)

if self.proc.returncode is not None:
output = (await self.proc.stdout.read()).decode().strip()
raise edgedb_cluster.ClusterError(output)
await stdout_task
raise edgedb_cluster.ClusterError(output.decode(errors='ignore'))
else:
assert status_task.done()
data = status_task.result()
Expand Down
4 changes: 3 additions & 1 deletion edb/tools/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from edb.testbase.server import get_test_cases
from edb.tools.edb import edbcommands

from .decorators import async_timeout
from .decorators import not_implemented
from .decorators import _xfail
from .decorators import xfail
Expand All @@ -49,7 +50,8 @@
from . import results


__all__ = ('not_implemented', 'xerror', 'xfail', '_xfail', 'skip')
__all__ = ('async_timeout', 'not_implemented', 'xerror', 'xfail', '_xfail',
'skip')


@edbcommands.command()
Expand Down
18 changes: 18 additions & 0 deletions edb/tools/test/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from __future__ import annotations

import asyncio
import functools
import unittest


Expand Down Expand Up @@ -55,3 +57,19 @@ def decorator(test_item):
return unittest.expectedFailure(test_item)

return decorator


def async_timeout(timeout: int):
def decorator(test_func):
@functools.wraps(test_func)
async def wrapper(*args, **kwargs):
try:
await asyncio.wait_for(test_func(*args, **kwargs), timeout)
except asyncio.TimeoutError:
raise AssertionError(
f"Test failed due to timeout after {timeout} seconds")
except asyncio.CancelledError as e:
raise AssertionError(
f"Test failed due to timeout after {timeout} seconds", e)
return wrapper
return decorator
21 changes: 14 additions & 7 deletions tests/test_server_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

from edb.server import connpool
from edb.server.connpool import pool as pool_impl
from edb.tools.test import async_timeout

# TIME_SCALE is used to run the simulation for longer time, the default is 1x.
TIME_SCALE = int(os.environ.get("TIME_SCALE", '1'))
Expand Down Expand Up @@ -1331,6 +1332,7 @@ async def q2(pool, event):
event.set()
pool.release('aaa', conn)

@async_timeout(timeout=5)
async def test(delay: float):
event = asyncio.Event()

Expand All @@ -1348,8 +1350,8 @@ async def test(delay: float):
g.create_task(q2(pool, event))

async def main():
await asyncio.wait_for(test(0.05), timeout=5)
await asyncio.wait_for(test(0.000001), timeout=5)
await test(0.05)
await test(0.000001)

asyncio.run(main())

Expand All @@ -1364,6 +1366,7 @@ async def q(db, pool, *, wait_event=None, set_event=None):
# print('RELEASE', db)
pool.release(db, conn)

@async_timeout(timeout=5)
async def test(delay: float):
e1 = asyncio.Event()
e2 = asyncio.Event()
Expand All @@ -1388,7 +1391,7 @@ async def test(delay: float):
e1.set()

async def main():
await asyncio.wait_for(test(0.05), timeout=5)
await test(0.05)

asyncio.run(main())

Expand All @@ -1410,6 +1413,7 @@ def _log(self, level, msg, args, *other, **kwargs):
@unittest.mock.patch('edb.server.connpool.pool.MIN_LOG_TIME_THRESHOLD',
0.2)
def test_connpool_log_batching(self, logger: MockLogger):
@async_timeout(timeout=5)
async def test():
pool = connpool.Pool(
connect=self.make_fake_connect(),
Expand Down Expand Up @@ -1446,7 +1450,7 @@ async def test():

async def main():
logger.logs = asyncio.Queue()
await asyncio.wait_for(test(), timeout=5)
await test()

asyncio.run(main())

Expand All @@ -1465,6 +1469,7 @@ async def fake_disconnect(conn):
nonlocal disconnect_called_num
disconnect_called_num += 1

@async_timeout(timeout=1)
async def test():
pool = connpool.Pool(
connect=fake_connect,
Expand All @@ -1481,7 +1486,7 @@ async def test():
self.assertEqual(disconnect_called_num, 0)

async def main():
await asyncio.wait_for(test(), timeout=1)
await test()

asyncio.run(main())

Expand Down Expand Up @@ -1526,6 +1531,7 @@ async def fake_connect(dbname):
else:
return await connect(dbname)

@async_timeout(timeout=5)
async def test():
pool = connpool.Pool(
connect=fake_connect,
Expand All @@ -1552,7 +1558,7 @@ async def test():

async def main():
logger.logs = asyncio.Queue()
await asyncio.wait_for(test(), timeout=5)
await test()

asyncio.run(main())

Expand All @@ -1561,6 +1567,7 @@ async def fake_connect(dbname):
# very fast connect
return FakeConnection(dbname)

@async_timeout(timeout=3)
async def test():
pool = connpool.Pool(
connect=fake_connect,
Expand All @@ -1585,7 +1592,7 @@ async def job(dbname):
g.create_task(job(f"block_{n}"))

async def main():
await asyncio.wait_for(test(), timeout=3)
await test()

asyncio.run(main())

Expand Down
Loading