forked from hummingbot/hummingbot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathisolated_asyncio_wrapper_test_case.py
280 lines (226 loc) · 10.3 KB
/
isolated_asyncio_wrapper_test_case.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
import asyncio
import functools
import unittest
from asyncio import Task
from collections.abc import Set
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, TypeVar
# Generic result type: the value produced by the coroutine function wrapped by `async_to_sync`.
T = TypeVar("T")
def async_to_sync(func: Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]:
    """
    Decorator to convert an async function into a function that can be called in sync context.

    The wrapper looks up the current event loop with `asyncio.get_event_loop()`:

    - if an open loop is returned, the coroutine is executed with `run_until_complete()`;
    - if no loop is available (`RuntimeError`) or the current loop has already been closed,
      the coroutine is executed with `asyncio.run()` instead.

    The wrapper must not be called from code that is already executing inside the event loop:
    `run_until_complete()` raises `RuntimeError` when the loop is running.

    :param func: The async function to be converted.
    :type func: Callable[..., Coroutine[Any, Any, T]]
    :return: The wrapped synchronous function.
    :rtype: Callable[..., T]

    Usage:

    .. code-block:: python

        class MyClass:
            @async_to_sync
            async def async_method(self) -> str:
                await asyncio.sleep(1)
                return "Hello, World!"

        my_instance = MyClass()
        result = my_instance.async_method()
        print(result)  # Output: Hello, World!
    """

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> T:
        try:
            loop: Optional[asyncio.AbstractEventLoop] = asyncio.get_event_loop()
        except RuntimeError:
            # No current event loop is set for this thread.
            loop = None

        # Decide before creating the coroutine, so that an unusable loop never leaves
        # a "coroutine was never awaited" object behind.
        if loop is None or loop.is_closed():
            return asyncio.run(func(*args, **kwargs))
        return loop.run_until_complete(func(*args, **kwargs))

    return wrapper
class IsolatedAsyncioWrapperTestCase(unittest.IsolatedAsyncioTestCase):
"""
Custom test case class that wraps `unittest.IsolatedAsyncioTestCase`.
This class provides additional functionality to set up and tear down the asyncio event loop for each test case.
It ensures that each test case runs in an isolated asyncio event loop, preventing interference between test cases.
Example usage:
```python
class MyTestCase(IsolatedAsyncioWrapperTestCase):
async def test_my_async_function(self):
# Test your async function here
...
```
"""
main_event_loop = None
@classmethod
def setUpClass(cls) -> None:
# Save the current event loop
try:
# This will trigger a RuntimeError if no event loop is running or no event loop is set.
# Meaning, set_event_loop(None) has been called.
cls.main_event_loop = asyncio.get_event_loop()
except RuntimeError:
# If no event loop exists, create one
cls.main_event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(cls.main_event_loop)
assert cls.main_event_loop is not None
super().setUpClass()
def setUp(self) -> None:
self.local_event_loop = asyncio.get_event_loop()
assert self.local_event_loop is not self.main_event_loop
super().setUp()
def tearDown(self) -> None:
super().tearDown()
if self.main_event_loop is not None and not self.main_event_loop.is_closed():
asyncio.set_event_loop(self.main_event_loop)
assert asyncio.get_event_loop() is self.main_event_loop
else:
asyncio.set_event_loop(asyncio.new_event_loop())
@classmethod
def tearDownClass(cls) -> None:
super().tearDownClass()
# Ok, asyncio.IsolatedAsyncioTestCase kills the main event loop no matter it's initial state.
# We need to restore it here, otherwise any tests after this one will fail if it relies on the main event loop.
if cls.main_event_loop is not None and not cls.main_event_loop.is_closed():
asyncio.set_event_loop(cls.main_event_loop)
else:
asyncio.set_event_loop(asyncio.new_event_loop())
assert asyncio.get_event_loop() is not None
def run_async_with_timeout(self, coroutine: Awaitable, timeout: float = 1.0) -> Any:
"""
Run the given coroutine with a timeout.
:param Awaitable coroutine: The coroutine to be executed.
:param float timeout: The timeout value in seconds.
:return: The result of the coroutine.
:rtype: Any
"""
return self.local_event_loop.run_until_complete(asyncio.wait_for(coroutine, timeout=timeout))
@staticmethod
async def await_task_completion(tasks_name: Optional[str | List[str]]) -> None:
"""
Await the completion of the given task.
Warning: This method relies on undocumented method of Task (get_coro()),
as well as internals of Python coroutines (cr_code.co_name).
:param str tasks_name: The task name (or names) to be awaited.
:return: The result of the task.
"""
def get_coro_func_name(task):
coro = task.get_coro()
return coro.cr_code.co_name
if tasks_name is None:
return
if isinstance(tasks_name, str):
tasks_name = [tasks_name]
tasks: Set[Task] = asyncio.all_tasks()
tasks = {task for task in tasks for task_name in tasks_name if task_name == get_coro_func_name(task)}
if tasks:
await asyncio.wait(tasks)
class LocalClassEventLoopWrapperTestCase(unittest.TestCase):
    """
    Custom test case class that wraps `unittest.TestCase`.

    This class provides additional functionality to manage the main event loop and a local event loop for tests.
    One local asyncio event loop is created per test class (in `setUpClass`) and shared by all of its tests;
    it is closed in `tearDownClass`, which then makes the previously current loop current again. This prevents
    interference between test suites.

    Example usage:
    ```python
    class MyTestCase(LocalClassEventLoopWrapperTestCase):
        def test_my_async_function(self):
            self.local_event_loop.run_until_complete(asyncio.sleep(0.1))
            ...
    ```

    Note:
    - Test methods are regular (synchronous) methods; coroutines are driven through `local_event_loop`
      or `run_async_with_timeout()`.
    - This class assumes that the tests are defined as methods in a subclass.

    Attributes:
    - `main_event_loop`: The reference to the main asyncio event loop.
    - `local_event_loop`: The local asyncio event loop shared by the tests of the class.
    """

    main_event_loop: Optional[asyncio.AbstractEventLoop] = None
    local_event_loop: Optional[asyncio.AbstractEventLoop] = None

    @classmethod
    def setUpClass(cls) -> None:
        """Remember the current event loop and install a new local loop for the class."""
        super().setUpClass()
        try:
            cls.main_event_loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop: create one so there is something to restore in tearDownClass().
            cls.main_event_loop = asyncio.new_event_loop()
        cls.local_event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(cls.local_event_loop)

    @classmethod
    def tearDownClass(cls) -> None:
        """
        Cancel what is still pending on the local loop, close it and restore a usable current loop.

        The saved main loop becomes current again; if it is missing or has been closed in the meantime,
        a new event loop is installed instead. The restore step runs even if draining the local loop fails.
        """
        local_loop = cls.local_event_loop
        try:
            if local_loop is not None and not local_loop.is_closed():
                pending: Set[Task] = asyncio.all_tasks(local_loop)
                for task in pending:
                    task.cancel()
                # gather() without arguments binds to the *current* loop, which is not necessarily
                # the local one; only call it when there really is something to wait for.
                if pending:
                    local_loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
                local_loop.run_until_complete(local_loop.shutdown_asyncgens())
        finally:
            if local_loop is not None and not local_loop.is_closed():
                local_loop.close()
            cls.local_event_loop = None

            main_loop = cls.main_event_loop
            cls.main_event_loop = None
            if main_loop is not None and not main_loop.is_closed():
                asyncio.set_event_loop(main_loop)
            else:
                asyncio.set_event_loop(asyncio.new_event_loop())
            super().tearDownClass()

    def run_async_with_timeout(self, coroutine: Awaitable, timeout: float = 1.0) -> Any:
        """
        Run the given coroutine with a timeout on the local event loop.

        :param Awaitable coroutine: The coroutine to be executed.
        :param float timeout: The timeout value in seconds.
        :return: The result of the coroutine.
        :rtype: Any
        """
        return self.local_event_loop.run_until_complete(asyncio.wait_for(coroutine, timeout=timeout))
class LocalTestEventLoopWrapperTestCase(unittest.TestCase):
    """
    Custom test case class that wraps `unittest.TestCase`.

    This class provides additional functionality to manage the main event loop and a local event loop for each test.
    A new asyncio event loop is created and made current in `setUp` and closed in `tearDown`, so every test
    runs on its own loop; `tearDownClass` restores the loop that was current before the class started.
    This prevents interference between tests and between test suites.

    Example usage:
    ```python
    class MyTestCase(LocalTestEventLoopWrapperTestCase):
        def test_my_async_function(self):
            self.local_event_loop.run_until_complete(asyncio.sleep(0.1))
            ...
    ```

    Note:
    - Test methods are regular (synchronous) methods; coroutines are driven through `local_event_loop`
      or `run_async_with_timeout()`.
    - This class assumes that the tests are defined as methods in a subclass.

    Attributes:
    - `main_event_loop`: The reference to the main asyncio event loop.
    - `local_event_loop`: The local asyncio event loop used for each test case.
    """

    main_event_loop: Optional[asyncio.AbstractEventLoop] = None
    local_event_loop: Optional[asyncio.AbstractEventLoop] = None

    @classmethod
    def setUpClass(cls) -> None:
        """Remember the current event loop, or `None` when no loop is set."""
        super().setUpClass()
        try:
            cls.main_event_loop = asyncio.get_event_loop()
        except RuntimeError:
            cls.main_event_loop = None

    def setUp(self) -> None:
        """Create a fresh event loop for the test and make it the current one."""
        super().setUp()
        self.local_event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.local_event_loop)
        self.assertEqual(asyncio.get_event_loop(), self.local_event_loop)

    def tearDown(self) -> None:
        """
        Cancel what is still pending on the test's loop and close it.

        The loop is closed and `local_event_loop` reset even if draining the loop fails; a loop the test
        already closed itself is left alone.
        """
        local_loop = self.local_event_loop
        try:
            if local_loop is not None and not local_loop.is_closed():
                pending: Set[Task] = asyncio.all_tasks(local_loop)
                for task in pending:
                    task.cancel()
                # gather() without arguments binds to the *current* loop, which is not necessarily
                # the local one; only call it when there really is something to wait for.
                if pending:
                    local_loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
                local_loop.run_until_complete(local_loop.shutdown_asyncgens())
        finally:
            if local_loop is not None and not local_loop.is_closed():
                local_loop.close()
            self.local_event_loop = None
            super().tearDown()

    @classmethod
    def tearDownClass(cls) -> None:
        """Make the saved main loop current again, or a new loop when it is missing or closed."""
        main_loop = cls.main_event_loop
        if main_loop is not None and not main_loop.is_closed():
            asyncio.set_event_loop(main_loop)
        else:
            asyncio.set_event_loop(asyncio.new_event_loop())
        cls.main_event_loop = None
        super().tearDownClass()

    def run_async_with_timeout(self, coroutine: Awaitable, timeout: float = 1.0) -> Any:
        """
        Run the given coroutine with a timeout on the test's local event loop.

        :param Awaitable coroutine: The coroutine to be executed.
        :param float timeout: The timeout value in seconds.
        :return: The result of the coroutine.
        :rtype: Any
        """
        return self.local_event_loop.run_until_complete(asyncio.wait_for(coroutine, timeout=timeout))