Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Connection Pool functionality to pool workers #187

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ ignore =
E2
E3
E4
max-line-length = 88
max-line-length = 115
per-file-ignores =
__init__.py: F401
157 changes: 116 additions & 41 deletions aiomultiprocess/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
TypeVar,
)

from .core import get_context, Process
from aiohttp import ClientSession, ClientTimeout, TCPConnector

from .core import Process, get_context
from .scheduler import RoundRobin, Scheduler
from .types import (
LoopInitializer,
Expand Down Expand Up @@ -55,7 +57,11 @@ def __init__(
initargs: Sequence[Any] = (),
loop_initializer: Optional[LoopInitializer] = None,
exception_handler: Optional[Callable[[BaseException], None]] = None,
init_client_session: bool = False,
session_base_url: Optional[str] = None,
) -> None:
self.init_client_session = init_client_session
self.session_base_url = session_base_url
super().__init__(
target=self.run,
initializer=initializer,
Expand All @@ -69,53 +75,116 @@ def __init__(
self.rx = rx

async def run(self) -> None:
    """Optionally open a shared aiohttp session, then run the work loop.

    When ``init_client_session`` is true, a single ``ClientSession`` (with
    its TCP connection pool) is created for the lifetime of this worker
    process and appended as an extra trailing positional argument to every
    task call, so task functions must accept it as their final positional
    parameter. Otherwise the work loop runs unchanged.
    """
    if self.init_client_session:
        async with ClientSession(
            connector=TCPConnector(
                # Allow at least 100 connections per host, more if this
                # worker's concurrency demands it.
                limit_per_host=max(100, self.concurrency),
                use_dns_cache=True,
            ),
            timeout=ClientTimeout(total=120),
            # Normalize empty string to None so aiohttp sees "no base URL".
            base_url=self.session_base_url or None,
        ) as client_session:
            await self._work_loop(client_session)
    else:
        await self._work_loop(None)

async def _work_loop(self, client_session) -> None:
    """Pick up work, execute work, return results, rinse, repeat.

    Shared by both the with-session and without-session modes; the only
    difference is whether *client_session* is appended to task args.
    """
    pending: Dict[asyncio.Future, TaskID] = {}
    completed = 0
    running = True
    while running or pending:
        # TTL, Tasks To Live, determines how many tasks to execute before dying
        if self.ttl and completed >= self.ttl:
            running = False

        # pick up new work as long as we're "running" and we have open slots
        while running and len(pending) < self.concurrency:
            try:
                task: PoolTask = self.tx.get_nowait()
            except queue.Empty:
                break

            # None is the sentinel telling this worker to shut down
            if task is None:
                running = False
                break

            tid, func, args, kwargs = task
            if client_session is not None:
                # NOTE: adds the shared client session to the args list
                args = [*args, client_session]
            future = asyncio.ensure_future(func(*args, **kwargs))
            pending[future] = tid

        if not pending:
            await asyncio.sleep(0.005)
            continue

        # return results and/or exceptions when completed
        done, _ = await asyncio.wait(
            pending.keys(), timeout=0.05, return_when=asyncio.FIRST_COMPLETED
        )
        for future in done:
            tid = pending.pop(future)

            result = None
            tb = None
            try:
                result = future.result()
            except BaseException as e:
                if self.exception_handler is not None:
                    self.exception_handler(e)

                tb = traceback.format_exc()

            self.rx.put_nowait((tid, result, tb))
            completed += 1


class PoolResult(Awaitable[Sequence[_T]], AsyncIterable[_T]):
Expand Down Expand Up @@ -159,6 +228,8 @@ def __init__(
scheduler: Scheduler = None,
loop_initializer: Optional[LoopInitializer] = None,
exception_handler: Optional[Callable[[BaseException], None]] = None,
init_client_session: bool = False,
session_base_url: Optional[str] = None,
) -> None:
self.context = get_context()

Expand All @@ -175,6 +246,8 @@ def __init__(
self.maxtasksperchild = max(0, maxtasksperchild)
self.childconcurrency = max(1, childconcurrency)
self.exception_handler = exception_handler
self.init_client_session = init_client_session
self.session_base_url = session_base_url

self.processes: Dict[Process, QueueID] = {}
self.queues: Dict[QueueID, Tuple[Queue, Queue]] = {}
Expand Down Expand Up @@ -257,6 +330,8 @@ def create_worker(self, qid: QueueID) -> Process:
initargs=self.initargs,
loop_initializer=self.loop_initializer,
exception_handler=self.exception_handler,
init_client_session=self.init_client_session,
session_base_url=self.session_base_url,
)
process.start()
return process
Expand Down
15 changes: 1 addition & 14 deletions aiomultiprocess/types.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
# Copyright 2022 Amethyst Reese
# Licensed under the MIT license

import multiprocessing
from asyncio import BaseEventLoop
from typing import (
Any,
Callable,
Dict,
NamedTuple,
NewType,
Optional,
Sequence,
Tuple,
TypeVar,
)
from typing import Any, Callable, Dict, NamedTuple, NewType, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")
R = TypeVar("R")
Expand Down