Skip to content

Commit

Permalink
WIP asyncio.Queue
Browse files Browse the repository at this point in the history
  • Loading branch information
carl-baillargeon committed Nov 11, 2024
1 parent 5f65676 commit 756b942
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 15 deletions.
174 changes: 172 additions & 2 deletions anta/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def __init__(self, name: str, tags: set[str] | None = None, *, disable_cache: bo
self.established: bool = False
self.cache: Cache | None = None
self.cache_locks: defaultdict[str, asyncio.Lock] | None = None
self.command_queue: asyncio.Queue[AntaCommand] = asyncio.Queue()
self.batch_task: asyncio.Task[None] | None = None
# TODO: Check if we want to make the batch size configurable
self.batch_size: int = 100

# Initialize cache if not disabled
if not disable_cache:
Expand All @@ -104,6 +108,12 @@ def _init_cache(self) -> None:
self.cache = Cache(cache_class=Cache.MEMORY, ttl=60, namespace=self.name, plugins=[HitMissRatioPlugin()])
self.cache_locks = defaultdict(asyncio.Lock)

def init_batch_task(self) -> None:
"""Initialize the batch task for the device."""
if self.batch_task is None:
logger.debug("<%s>: Starting the batch task", self.name)
self.batch_task = asyncio.create_task(self._batch_task())

@property
def cache_statistics(self) -> dict[str, Any] | None:
"""Return the device cache statistics for logging purposes."""
Expand Down Expand Up @@ -137,6 +147,72 @@ def __repr__(self) -> str:
f"disable_cache={self.cache is None!r})"
)

async def _batch_task(self) -> None:
"""Background task to retrieve commands put by tests from the command queue of this device.
Test coroutines put their AntaCommand instances in the queue, this task retrieves them. Once they stop coming,
the instances are grouped by UID, split into JSON and text batches, and collected in batches of `batch_size`.
"""
collection_tasks: list[asyncio.Task[None]] = []
all_commands: list[AntaCommand] = []

while True:
try:
get_await = self.command_queue.get()
command = await asyncio.wait_for(get_await, timeout=0.5)
logger.debug("<%s>: Command retrieved from the queue: %s", self.name, command)
all_commands.append(command)
except asyncio.TimeoutError: # noqa: PERF203
logger.debug("<%s>: All test commands have been retrieved from the queue", self.name)
break

# Group all command instances by UID
command_groups: defaultdict[str, list[AntaCommand]] = defaultdict(list[AntaCommand])
for command in all_commands:
command_groups[command.uid].append(command)

# Split into JSON and text batches. We can safely take the first command instance from each UID as they are the same.
json_commands = {uid: commands for uid, commands in command_groups.items() if commands[0].ofmt == "json"}
text_commands = {uid: commands for uid, commands in command_groups.items() if commands[0].ofmt == "text"}

# Process JSON batches
for i in range(0, len(json_commands), self.batch_size):
batch = dict(list(json_commands.items())[i : i + self.batch_size])
task = asyncio.create_task(self._collect_batch(batch, ofmt="json"))
collection_tasks.append(task)

# Process text batches
for i in range(0, len(text_commands), self.batch_size):
batch = dict(list(text_commands.items())[i : i + self.batch_size])
task = asyncio.create_task(self._collect_batch(batch, ofmt="text"))
collection_tasks.append(task)

# Wait for all collection tasks to complete
if collection_tasks:
logger.debug("<%s>: Waiting for %d collection tasks to complete", self.name, len(collection_tasks))
await asyncio.gather(*collection_tasks)

# TODO: Handle other exceptions

logger.debug("<%s>: Stopping the batch task", self.name)

async def _collect_batch(self, command_groups: dict[str, list[AntaCommand]], ofmt: Literal["json", "text"] = "json") -> None:
"""Collect a batch of device commands.
This coroutine must be implemented by subclasses that want to support command queuing
in conjunction with the `_batch_task()` method.
Parameters
----------
command_groups
Mapping of command instances grouped by UID to avoid duplicate commands.
ofmt
The output format of the batch.
"""
_ = (command_groups, ofmt)
msg = f"_collect_batch method has not been implemented in {self.__class__.__name__} definition"
raise NotImplementedError(msg)

@abstractmethod
async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None:
"""Collect device command output.
Expand Down Expand Up @@ -192,16 +268,38 @@ async def collect(self, command: AntaCommand, *, collection_id: str | None = Non
else:
await self._collect(command=command, collection_id=collection_id)

async def collect_commands(self, commands: list[AntaCommand], *, collection_id: str | None = None) -> None:
async def collect_commands(self, commands: list[AntaCommand], *, command_queuing: bool = False, collection_id: str | None = None) -> None:
"""Collect multiple commands.
Parameters
----------
commands
The commands to collect.
command_queuing
If True, the commands are put in a queue and collected in batches. Default is False.
collection_id
An identifier used to build the eAPI request ID.
An identifier used to build the eAPI request ID. Not used when command queuing is enabled.
"""
# Collect the commands with queuing
if command_queuing:
# Disable cache for this device as it is not needed when using command queuing
self.cache = None
self.cache_locks = None

# Initialize the device batch task if not already running
self.init_batch_task()

# Put the commands in the queue
for command in commands:
logger.debug("<%s>: Putting command in the queue: %s", self.name, command)
await self.command_queue.put(command)

# Wait for all commands to be collected.
logger.debug("<%s>: Waiting for all commands to be collected", self.name)
await asyncio.gather(*[command.event.wait() for command in commands])
return

# Collect the commands without queuing. Default behavior.
await asyncio.gather(*(self.collect(command=command, collection_id=collection_id) for command in commands))

@abstractmethod
Expand Down Expand Up @@ -372,6 +470,78 @@ def _keys(self) -> tuple[Any, ...]:
"""
return (self._session.host, self._session.port)

async def _collect_batch(self, command_groups: dict[str, list[AntaCommand]], ofmt: Literal["json", "text"] = "json") -> None: # noqa: C901
"""Collect a batch of device commands.
Parameters
----------
command_groups
Mapping of command instances grouped by UID to avoid duplicate commands.
ofmt
The output format of the batch.
"""
# Add 'enable' command if required
cmds = []
if self.enable and self._enable_password is not None:
cmds.append({"cmd": "enable", "input": str(self._enable_password)})
elif self.enable:
# No password
cmds.append({"cmd": "enable"})

# Take first instance from each group for the actual commands
cmds.extend(
[
{"cmd": instances[0].command, "revision": instances[0].revision} if instances[0].revision else {"cmd": instances[0].command}
for instances in command_groups.values()
]
)

try:
response = await self._session.cli(
commands=cmds,
ofmt=ofmt,
# TODO: See if we want to have different batches for different versions
version=1,
# TODO: See if want to have a different req_id for each batch
req_id=f"ANTA-{id(command_groups)}",
)

# Do not keep response of 'enable' command
if self.enable:
response = response[1:]

# Update all AntaCommand instances with their output and signal their completion
logger.debug("<%s>: Collected batch of commands, signaling their completion", self.name)
for idx, instances in enumerate(command_groups.values()):
output = response[idx]
for cmd_instance in instances:
cmd_instance.output = output
cmd_instance.event.set()

except asynceapi.EapiCommandError as e:
# TODO: Handle commands that passed
for instances in command_groups.values():
for cmd_instance in instances:
cmd_instance.errors = e.errors
if cmd_instance.requires_privileges:
logger.error(
"Command '%s' requires privileged mode on %s. Verify user permissions and if the `enable` option is required.",
cmd_instance.command,
self.name,
)
if cmd_instance.supported:
logger.error("Command '%s' failed on %s: %s", cmd_instance.command, self.name, e.errors[0] if len(e.errors) == 1 else e.errors)
else:
logger.debug("Command '%s' is not supported on '%s' (%s)", cmd_instance.command, self.name, self.hw_model)
cmd_instance.event.set()

# TODO: Handle other exceptions
except Exception as e:
for instances in command_groups.values():
for cmd_instance in instances:
cmd_instance.errors = [exc_to_str(e)]
cmd_instance.event.set()

async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None: # noqa: C901 function is too complex - because of many required except blocks
"""Collect device command output from EOS using aio-eapi.
Expand Down
37 changes: 26 additions & 11 deletions anta/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

from __future__ import annotations

import asyncio
import hashlib
import logging
import re
from abc import ABC, abstractmethod
from functools import wraps
from functools import cached_property, wraps
from string import Formatter
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, TypeVar

Expand Down Expand Up @@ -165,7 +166,9 @@ class AntaCommand(BaseModel):
Pydantic Model containing the variables values used to render the template.
use_cache
Enable or disable caching for this AntaCommand if the AntaDevice supports it.
event
Event to signal that the command has been collected. Used by an AntaDevice to signal an AntaTest that the command has been collected.
Only relevant when an AntaTest runs with `command_queuing=True`.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)
Expand All @@ -179,13 +182,13 @@ class AntaCommand(BaseModel):
errors: list[str] = []
params: AntaParamsBaseModel = AntaParamsBaseModel()
use_cache: bool = True
event: asyncio.Event | None = None

@property
@cached_property
def uid(self) -> str:
"""Generate a unique identifier for this command."""
uid_str = f"{self.command}_{self.version}_{self.revision or 'NA'}_{self.ofmt}"
# Ignoring S324 probable use of insecure hash function - sha1 is enough for our needs.
return hashlib.sha1(uid_str.encode()).hexdigest() # noqa: S324
return hashlib.sha256(uid_str.encode()).hexdigest()

@property
def json_output(self) -> dict[str, Any]:
Expand Down Expand Up @@ -409,6 +412,8 @@ def __init__(
device: AntaDevice,
inputs: dict[str, Any] | AntaTest.Input | None = None,
eos_data: list[dict[Any, Any] | str] | None = None,
*,
command_queuing: bool = False,
) -> None:
"""AntaTest Constructor.
Expand All @@ -421,10 +426,14 @@ def __init__(
eos_data
Populate outputs of the test commands instead of collecting from devices.
This list must have the same length and order than the `instance_commands` instance attribute.
command_queuing
If True, the commands of this test will be queued in the device command queue and be sent in batches.
Default is False, which means the commands will be sent one by one to the device.
"""
self.logger: logging.Logger = logging.getLogger(f"{self.module}.{self.__class__.__name__}")
self.device: AntaDevice = device
self.inputs: AntaTest.Input
self.command_queuing = command_queuing
self.instance_commands: list[AntaCommand] = []
self.result: TestResult = TestResult(
name=device.name,
Expand Down Expand Up @@ -474,10 +483,17 @@ def _init_commands(self, eos_data: list[dict[Any, Any] | str] | None) -> None:
if self.__class__.commands:
for cmd in self.__class__.commands:
if isinstance(cmd, AntaCommand):
self.instance_commands.append(cmd.model_copy())
command = cmd.model_copy()
if self.command_queuing:
command.event = asyncio.Event()
self.instance_commands.append(command)
elif isinstance(cmd, AntaTemplate):
try:
self.instance_commands.extend(self.render(cmd))
rendered_commands = self.render(cmd)
if self.command_queuing:
for command in rendered_commands:
command.event = asyncio.Event()
self.instance_commands.extend(rendered_commands)
except AntaTemplateRenderError as e:
self.result.is_error(message=f"Cannot render template {{{e.template}}}")
return
Expand Down Expand Up @@ -568,7 +584,7 @@ async def collect(self) -> None:
"""Collect outputs of all commands of this test class from the device of this test instance."""
try:
if self.blocked is False:
await self.device.collect_commands(self.instance_commands, collection_id=self.name)
await self.device.collect_commands(self.instance_commands, collection_id=self.name, command_queuing=self.command_queuing)
except Exception as e: # noqa: BLE001
# device._collect() is user-defined code.
# We need to catch everything if we want the AntaTest object
Expand All @@ -593,7 +609,6 @@ def anta_test(function: F) -> Callable[..., Coroutine[Any, Any, TestResult]]:
async def wrapper(
self: AntaTest,
eos_data: list[dict[Any, Any] | str] | None = None,
**kwargs: dict[str, Any],
) -> TestResult:
"""Inner function for the anta_test decorator.
Expand Down Expand Up @@ -640,7 +655,7 @@ async def wrapper(
return self.result

try:
function(self, **kwargs)
function(self)
except Exception as e: # noqa: BLE001
# test() is user-defined code.
# We need to catch everything if we want the AntaTest object
Expand All @@ -662,7 +677,7 @@ def update_progress(cls: type[AntaTest]) -> None:
cls.progress.update(cls.nrfu_task, advance=1)

@abstractmethod
def test(self) -> Coroutine[Any, Any, TestResult]:
def test(self) -> None:
"""Core of the test logic.
This is an abstractmethod that must be implemented by child classes.
Expand Down
17 changes: 15 additions & 2 deletions anta/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@
logger = logging.getLogger(__name__)

DEFAULT_NOFILE = 16384
COMMAND_QUEUING = False


def get_command_queuing() -> bool:
"""Return the command queuing flag from the environment variable if set."""
try:
command_queuing = bool(os.environ.get("ANTA_COMMAND_QUEUING", COMMAND_QUEUING))
except ValueError as exception:
logger.warning("The ANTA_COMMAND_QUEUING environment variable value is invalid: %s\nDefault to %s.", exc_to_str(exception), COMMAND_QUEUING)
command_queuing = COMMAND_QUEUING
return command_queuing


def adjust_rlimit_nofile() -> tuple[int, int]:
Expand Down Expand Up @@ -190,11 +201,12 @@ def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinitio
list[Coroutine[Any, Any, TestResult]]
The list of coroutines to run.
"""
command_queuing = get_command_queuing()
coros = []
for device, test_definitions in selected_tests.items():
for test in test_definitions:
try:
test_instance = test.test(device=device, inputs=test.inputs)
test_instance = test.test(device=device, inputs=test.inputs, command_queuing=command_queuing)
manager.add(test_instance.result)
coros.append(test_instance.test())
except Exception as e: # noqa: PERF203, BLE001
Expand Down Expand Up @@ -296,4 +308,5 @@ async def main( # noqa: PLR0913
with Catchtime(logger=logger, message="Running ANTA tests"):
await asyncio.gather(*coroutines)

log_cache_statistics(selected_inventory.devices)
if not get_command_queuing():
log_cache_statistics(selected_inventory.devices)

0 comments on commit 756b942

Please sign in to comment.