From 85efd8b211ca851e8fab8491e2e32aa1f8afaaaf Mon Sep 17 00:00:00 2001 From: Arjun Purushothaman Date: Tue, 25 Feb 2025 16:36:25 +0000 Subject: [PATCH] add type annotations and docstrings to devlib Most of the files are covered, but some of the instruments and unused platforms are not augmented --- .gitignore | 4 + devlib/_target_runner.py | 133 +- devlib/collector/__init__.py | 70 +- devlib/collector/dmesg.py | 143 +- devlib/collector/ftrace.py | 334 ++- devlib/collector/logcat.py | 56 +- devlib/collector/perf.py | 244 +- devlib/collector/perfetto.py | 67 +- devlib/collector/screencapture.py | 76 +- devlib/collector/serial_trace.py | 82 +- devlib/collector/systrace.py | 82 +- devlib/connection.py | 531 +++- devlib/exception.py | 30 +- devlib/host.py | 241 +- devlib/instrument/__init__.py | 482 +++- devlib/instrument/acmecape.py | 9 +- devlib/instrument/arm_energy_probe.py | 19 +- devlib/instrument/daq.py | 129 +- devlib/instrument/frames.py | 78 +- devlib/instrument/hwmon.py | 33 +- devlib/module/__init__.py | 255 +- devlib/module/android.py | 64 +- devlib/module/biglittle.py | 262 +- devlib/module/cgroups.py | 385 +-- devlib/module/cgroups2.py | 468 ++- devlib/module/cooling.py | 44 +- devlib/module/cpufreq.py | 271 +- devlib/module/cpuidle.py | 120 +- devlib/module/devfreq.py | 102 +- devlib/module/gpufreq.py | 35 +- devlib/module/hotplug.py | 83 +- devlib/module/hwmon.py | 124 +- devlib/module/sched.py | 234 +- devlib/module/thermal.py | 136 +- devlib/module/vexpress.py | 169 +- devlib/platform/__init__.py | 54 +- devlib/platform/arm.py | 173 +- devlib/target.py | 3824 ++++++++++++++++++------- devlib/utils/android.py | 1036 +++++-- devlib/utils/annotation_helpers.py | 72 + devlib/utils/asyn.py | 553 +++- devlib/utils/gem5.py | 4 +- devlib/utils/misc.py | 659 +++-- devlib/utils/parse_aep.py | 4 +- devlib/utils/rendering.py | 41 +- devlib/utils/serial_port.py | 47 +- devlib/utils/ssh.py | 1076 +++++-- devlib/utils/types.py | 25 +- devlib/utils/uboot.py | 5 +- 
devlib/utils/uefi.py | 9 +- devlib/utils/version.py | 17 +- mypy.ini | 6 + py.typed | 0 setup.py | 20 +- tests/test_target.py | 13 +- 55 files changed, 9239 insertions(+), 3994 deletions(-) mode change 100755 => 100644 devlib/utils/android.py create mode 100644 devlib/utils/annotation_helpers.py mode change 100755 => 100644 devlib/utils/parse_aep.py create mode 100644 mypy.ini create mode 100644 py.typed diff --git a/.gitignore b/.gitignore index 291b5354d..6ae2b075d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,7 @@ devlib/bin/scripts/shutils doc/_build/ build/ dist/ +.venv/ +.vscode/ +venv/ +.history/ \ No newline at end of file diff --git a/devlib/_target_runner.py b/devlib/_target_runner.py index a45612354..640819c1d 100644 --- a/devlib/_target_runner.py +++ b/devlib/_target_runner.py @@ -1,4 +1,4 @@ -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,15 +17,24 @@ Target runner and related classes are implemented here. """ -import logging import os import time + from platform import machine +from typing import Optional, cast, Protocol, TYPE_CHECKING, Union +from typing_extensions import NotRequired, LiteralString, TypedDict +if TYPE_CHECKING: + from _typeshed import StrPath, BytesPath + from devlib.platform import Platform +else: + StrPath = str + BytesPath = bytes from devlib.exception import (TargetStableError, HostError) -from devlib.target import LinuxTarget -from devlib.utils.misc import get_subprocess, which +from devlib.target import LinuxTarget, Target +from devlib.utils.misc import get_subprocess, which, get_logger from devlib.utils.ssh import SshConnection +from devlib.utils.annotation_helpers import SubprocessCommand, SshUserConnectionSettings class TargetRunner: @@ -36,16 +45,14 @@ class TargetRunner: (e.g., :class:`QEMUTargetRunner`). 
:param target: Specifies type of target per :class:`Target` based classes. - :type target: Target """ def __init__(self, - target): + target: Target) -> None: self.target = target + self.logger = get_logger(self.__class__.__name__) - self.logger = logging.getLogger(self.__class__.__name__) - - def __enter__(self): + def __enter__(self) -> 'TargetRunner': return self def __exit__(self, *_): @@ -58,18 +65,14 @@ class SubprocessTargetRunner(TargetRunner): :param runner_cmd: The command to start runner process (e.g., ``qemu-system-aarch64 -kernel Image -append "console=ttyAMA0" ...``). - :type runner_cmd: list(str) :param target: Specifies type of target per :class:`Target` based classes. - :type target: Target :param connect: Specifies if :class:`TargetRunner` should try to connect target after launching it, defaults to True. - :type connect: bool or None :param boot_timeout: Timeout for target's being ready for SSH access in seconds, defaults to 60. - :type boot_timeout: int or None :raises HostError: if it cannot execute runner command successfully. @@ -77,10 +80,10 @@ class SubprocessTargetRunner(TargetRunner): """ def __init__(self, - runner_cmd, - target, - connect=True, - boot_timeout=60): + runner_cmd: SubprocessCommand, + target: Target, + connect: bool = True, + boot_timeout: int = 60): super().__init__(target=target) self.boot_timeout = boot_timeout @@ -90,7 +93,7 @@ def __init__(self, try: self.runner_process = get_subprocess(runner_cmd) except Exception as ex: - raise HostError(f'Error while running "{runner_cmd}": {ex}') from ex + raise HostError(f'Error while running "{runner_cmd!r}": {ex}') from ex if connect: self.wait_boot_complete() @@ -107,16 +110,16 @@ def __exit__(self, *_): self.terminate() - def wait_boot_complete(self): + def wait_boot_complete(self) -> None: """ - Wait for target OS to finish boot up and become accessible over SSH in at most - ``SubprocessTargetRunner.boot_timeout`` seconds. 
+ Wait for the target OS to finish booting and become accessible within + :attr:`boot_timeout` seconds. - :raises TargetStableError: In case of timeout. + :raises TargetStableError: If the target is inaccessible after the timeout. """ start_time = time.time() - elapsed = 0 + elapsed: float = 0.0 while self.boot_timeout >= elapsed: try: self.target.connect(timeout=self.boot_timeout - elapsed) @@ -132,9 +135,9 @@ def wait_boot_complete(self): self.terminate() raise TargetStableError(f'Target is inaccessible for {self.boot_timeout} seconds!') - def terminate(self): + def terminate(self) -> None: """ - Terminate ``SubprocessTargetRunner.runner_process``. + Terminate the subprocess associated with this runner. """ self.logger.debug('Killing target runner...') @@ -147,10 +150,9 @@ class NOPTargetRunner(TargetRunner): Class for implementing a target runner which does nothing except providing .target attribute. :param target: Specifies type of target per :class:`Target` based classes. - :type target: Target """ - def __init__(self, target): + def __init__(self, target: Target) -> None: super().__init__(target=target) def __enter__(self): @@ -159,11 +161,61 @@ def __enter__(self): def __exit__(self, *_): pass - def terminate(self): + def terminate(self) -> None: """ Nothing to terminate for NOP target runners. Defined to be compliant with other runners (e.g., ``SubprocessTargetRunner``). 
""" + pass + + +class QEMUTargetUserSettings(TypedDict, total=False): + kernel_image: str + arch: NotRequired[str] + cpu_type: NotRequired[str] + initrd_image: str + mem_size: NotRequired[int] + num_cores: NotRequired[int] + num_threads: NotRequired[int] + cmdline: NotRequired[str] + enable_kvm: NotRequired[bool] + + +class QEMUTargetRunnerSettings(TypedDict, total=False): + kernel_image: str + arch: str + cpu_type: str + initrd_image: str + mem_size: int + num_cores: int + num_threads: int + cmdline: str + enable_kvm: bool + + +class SshConnectionSettings(TypedDict, total=False): + username: str + password: str + keyfile: Optional[Union[LiteralString, StrPath, BytesPath]] + host: str + port: int + timeout: float + platform: 'Platform' + sudo_cmd: str + strict_host_check: bool + use_scp: bool + poll_transfers: bool + start_transfer_poll_delay: int + total_transfer_timeout: int + transfer_poll_period: int + + +class QEMUTargetRunnerTargetFactory(Protocol): + """ + Protocol for Lambda function for creating :class:`Target` based object. + """ + def __call__(self, *, connect: bool, conn_cls, connection_settings: SshConnectionSettings) -> Target: + ... class QEMUTargetRunner(SubprocessTargetRunner): @@ -177,7 +229,7 @@ class QEMUTargetRunner(SubprocessTargetRunner): * ``arch``: Architecture type. Defaults to ``aarch64``. - * ``cpu_types``: List of CPU ids for QEMU. The list only contains ``cortex-a72`` by + * ``cpu_type``: List of CPU ids for QEMU. The list only contains ``cortex-a72`` by default. This parameter is valid for Arm architectures only. * ``initrd_image``: This points to the location of initrd image (e.g., @@ -197,14 +249,11 @@ class QEMUTargetRunner(SubprocessTargetRunner): * ``enable_kvm``: Specifies if KVM will be used as accelerator in QEMU or not. Enabled by default if host architecture matches with target's for improving QEMU performance. 
- :type qemu_settings: Dict :param connection_settings: the dictionary to store connection settings of ``Target.connection_settings``, defaults to None. - :type connection_settings: Dict or None :param make_target: Lambda function for creating :class:`Target` based object. - :type make_target: func or None :Variable positional arguments: Forwarded to :class:`TargetRunner`. @@ -212,21 +261,25 @@ class QEMUTargetRunner(SubprocessTargetRunner): """ def __init__(self, - qemu_settings, - connection_settings=None, - make_target=LinuxTarget, - **args): + qemu_settings: QEMUTargetUserSettings, + connection_settings: Optional[SshUserConnectionSettings] = None, + make_target: QEMUTargetRunnerTargetFactory = cast(QEMUTargetRunnerTargetFactory, LinuxTarget), + **args) -> None: - self.connection_settings = { + default_connection_settings = { 'host': '127.0.0.1', 'port': 8022, 'username': 'root', 'password': 'root', 'strict_host_check': False, } - self.connection_settings = {**self.connection_settings, **(connection_settings or {})} - qemu_args = { + self.connection_settings: SshConnectionSettings = cast(SshConnectionSettings, { + **default_connection_settings, + **(connection_settings or {}) + }) + + qemu_default_args = { 'arch': 'aarch64', 'cpu_type': 'cortex-a72', 'mem_size': 512, @@ -235,7 +288,7 @@ def __init__(self, 'cmdline': 'console=ttyAMA0', 'enable_kvm': True, } - qemu_args = {**qemu_args, **qemu_settings} + qemu_args: QEMUTargetRunnerSettings = cast(QEMUTargetRunnerSettings, {**qemu_default_args, **qemu_settings}) qemu_executable = f'qemu-system-{qemu_args["arch"]}' qemu_path = which(qemu_executable) diff --git a/devlib/collector/__init__.py b/devlib/collector/__init__.py index 0bc22ff07..060959414 100644 --- a/devlib/collector/__init__.py +++ b/devlib/collector/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in 
compliance with the License. @@ -16,27 +16,63 @@ import logging from devlib.utils.types import caseless_string +from devlib.utils.misc import get_logger +from typing import TYPE_CHECKING, Optional, List +if TYPE_CHECKING: + from devlib.target import Target + class CollectorBase(object): + """ + The `Collector` API provide a consistent way of collecting arbitrary data from + a target. Data is collected via an instance of a class derived from :class:`CollectorBase`. - def __init__(self, target): + :param target: The devlib Target from which data will be collected. + """ + def __init__(self, target: 'Target'): self.target = target - self.logger = logging.getLogger(self.__class__.__name__) - self.output_path = None - - def reset(self): + self.logger: logging.Logger = get_logger(self.__class__.__name__) + self.output_path: Optional[str] = None + + def reset(self) -> None: + """ + This can be used to configure a collector for collection. This must be invoked + before :meth:`start()` is called to begin collection. + """ pass - def start(self): + def start(self) -> None: + """ + Starts collecting from the target. + """ pass def stop(self): + """ + Stops collecting from target. Must be called after + :func:`start()`. + """ pass - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: + """ + Configure the output path for the particular collector. This will be either + a directory or file path which will be used when storing the data. Please see + the individual Collector documentation for more information. + + :param output_path: The path (file or directory) to which data will be saved. + """ self.output_path = output_path - def get_data(self): + def get_data(self) -> 'CollectorOutput': + """ + The collected data will be return via the previously specified output_path. 
+ This method will return a :class:`CollectorOutput` object which is a subclassed + list object containing individual ``CollectorOutputEntry`` objects with details + about the individual output entry. + + :raises RuntimeError: If ``output_path`` has not been set. + """ return CollectorOutput() def __enter__(self): @@ -47,18 +83,26 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.stop() + class CollectorOutputEntry(object): + """ + This object is designed to allow for the output of a collector to be processed + generically. The object will behave as a regular string containing the path to + underlying output path and can be used directly in ``os.path`` operations. - path_kinds = ['file', 'directory'] + :param path: The file path of the collected output data. + :param path_kind: The type of output. Must be one of ``file`` or ``directory``. + """ + path_kinds: List[str] = ['file', 'directory'] - def __init__(self, path, path_kind): - self.path = path + def __init__(self, path: str, path_kind: str): + self.path = path # path for the corresponding output item path_kind = caseless_string(path_kind) if path_kind not in self.path_kinds: msg = '{} is not a valid path_kind [{}]' raise ValueError(msg.format(path_kind, ' '.join(self.path_kinds))) - self.path_kind = path_kind + self.path_kind = path_kind # file or directory def __str__(self): return self.path diff --git a/devlib/collector/dmesg.py b/devlib/collector/dmesg.py index 06676aaa6..93a1e77f3 100644 --- a/devlib/collector/dmesg.py +++ b/devlib/collector/dmesg.py @@ -1,4 +1,4 @@ -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -21,10 +21,15 @@ from devlib.collector import (CollectorBase, CollectorOutput, CollectorOutputEntry) from devlib.exception import TargetStableError -from devlib.utils.misc import memoized +from devlib.utils.misc import memoized, get_logger +from typing import (Pattern, Optional, Match, Tuple, List, + Union, Any, TYPE_CHECKING) +from collections.abc import Generator +if TYPE_CHECKING: + from devlib.target import Target -_LOGGER = logging.getLogger('dmesg') +_LOGGER: logging.Logger = get_logger('dmesg') class KernelLogEntry(object): @@ -32,28 +37,24 @@ class KernelLogEntry(object): Entry of the kernel ring buffer. :param facility: facility the entry comes from - :type facility: str :param level: log level - :type level: str :param timestamp: Timestamp of the entry - :type timestamp: datetime.timedelta :param msg: Content of the entry - :type msg: str :param line_nr: Line number at which this entry appeared in the ``dmesg`` output. Note that this is not guaranteed to be unique across collectors, as the buffer can be cleared. The timestamp is the only reliable index. 
- :type line_nr: int """ - _TIMESTAMP_MSG_REGEX = re.compile(r'\[(.*?)\] (.*)') - _RAW_LEVEL_REGEX = re.compile(r'<([0-9]+)>(.*)') - _PRETTY_LEVEL_REGEX = re.compile(r'\s*([a-z]+)\s*:([a-z]+)\s*:\s*(.*)') + _TIMESTAMP_MSG_REGEX: Pattern[str] = re.compile(r'\[(.*?)\] (.*)') + _RAW_LEVEL_REGEX: Pattern[str] = re.compile(r'<([0-9]+)>(.*)') + _PRETTY_LEVEL_REGEX: Pattern[str] = re.compile(r'\s*([a-z]+)\s*:([a-z]+)\s*:\s*(.*)') - def __init__(self, facility, level, timestamp, msg, line_nr=0): + def __init__(self, facility: Optional[str], level: str, + timestamp: timedelta, msg: str, line_nr: int = 0): self.facility = facility self.level = level self.timestamp = timestamp @@ -61,7 +62,7 @@ def __init__(self, facility, level, timestamp, msg, line_nr=0): self.line_nr = line_nr @classmethod - def from_str(cls, line, line_nr=0): + def from_str(cls, line: str, line_nr: int = 0) -> 'KernelLogEntry': """ Parses a "dmesg --decode" output line, formatted as following: kern :err : [3618282.310743] nouveau 0000:01:00.0: systemd-logind[988]: nv50cal_space: -16 @@ -69,10 +70,15 @@ def from_str(cls, line, line_nr=0): Or the more basic output given by "dmesg -r": <3>[3618282.310743] nouveau 0000:01:00.0: systemd-logind[988]: nv50cal_space: -16 + :param line: A string from dmesg. + :param line_nr: The line number in the overall log. + :raises ValueError: If the line format is invalid. + :return: A constructed :class:`KernelLogEntry`. 
+ """ - def parse_raw_level(line): - match = cls._RAW_LEVEL_REGEX.match(line) + def parse_raw_level(line: str) -> Tuple[str, Union[str, Any]]: + match: Optional[Match[str]] = cls._RAW_LEVEL_REGEX.match(line) if not match: raise ValueError(f'dmesg entry format not recognized: {line}') level, remainder = match.groups() @@ -81,15 +87,15 @@ def parse_raw_level(line): level = levels[int(level) % len(levels)] return level, remainder - def parse_pretty_level(line): - match = cls._PRETTY_LEVEL_REGEX.match(line) + def parse_pretty_level(line: str) -> Tuple[str, str, str]: + match: Optional[Match[str]] = cls._PRETTY_LEVEL_REGEX.match(line) if not match: raise ValueError(f'dmesg entry pretty format not recognized: {line}') facility, level, remainder = match.groups() return facility, level, remainder - def parse_timestamp_msg(line): - match = cls._TIMESTAMP_MSG_REGEX.match(line) + def parse_timestamp_msg(line: str) -> Tuple[timedelta, str]: + match: Optional[Match[str]] = cls._TIMESTAMP_MSG_REGEX.match(line) if not match: raise ValueError(f'dmesg entry timestamp format not recognized: {line}') timestamp, msg = match.groups() @@ -101,7 +107,7 @@ def parse_timestamp_msg(line): # If we can parse the raw prio directly, that is a basic line try: level, remainder = parse_raw_level(line) - facility = None + facility: Optional[str] = None except ValueError: facility, level, remainder = parse_pretty_level(line) @@ -116,21 +122,23 @@ def parse_timestamp_msg(line): ) @classmethod - def from_dmesg_output(cls, dmesg_out, error=None): + def from_dmesg_output(cls, dmesg_out: Optional[str], error: Optional[str] = None) -> Generator['KernelLogEntry', None, None]: """ Return a generator of :class:`KernelLogEntry` for each line of the output of dmesg command. + :param dmesg_out: The dmesg output to parse. + :param error: If ``"raise"`` or ``None``, an exception will be raised if a parsing error occurs. If ``"warn"``, it will be logged at WARNING level. If ``"ignore"``, it will be ignored. 
If a callable is passed, the exception will be passed to it. - :type error: str or None or typing.Callable[[BaseException], None] + :return: A generator of parsed :class:`KernelLogEntry` objects. .. note:: The same restrictions on the dmesg output format as for :meth:`from_str` apply. """ - for i, line in enumerate(dmesg_out.splitlines()): + for i, line in enumerate(dmesg_out.splitlines() if dmesg_out else ''): if line.strip(): try: yield cls.from_str(line, line_nr=i) @@ -160,25 +168,25 @@ class DmesgCollector(CollectorBase): """ Dmesg output collector. + :param target: The devlib Target (must be rooted). + :param level: Minimum log level to enable. All levels that are more critical will be collected as well. - :type level: str :param facility: Facility to record, see dmesg --help for the list. - :type level: str :param empty_buffer: If ``True``, the kernel dmesg ring buffer will be emptied before starting. Note that this will break nesting of collectors, so it's not recommended unless it's really necessary. - :type empty_buffer: bool + :param parse_error: A string to be appended to error lines if parse fails. .. warning:: If BusyBox dmesg is used, facility and level will be ignored, and the parsed entries will also lack that information. 
""" # taken from "dmesg --help" # This list needs to be ordered by priority - LOG_LEVELS = [ + LOG_LEVELS: List[str] = [ "emerg", # system is unusable "alert", # action must be taken immediately "crit", # critical conditions @@ -189,13 +197,15 @@ class DmesgCollector(CollectorBase): "debug", # debug-level messages ] - def __init__(self, target, level=LOG_LEVELS[-1], facility='kern', empty_buffer=False, parse_error=None): + def __init__(self, target: 'Target', level: str = LOG_LEVELS[-1], + facility: str = 'kern', empty_buffer: bool = False, + parse_error: Optional[str] = None): super(DmesgCollector, self).__init__(target) if not target.is_rooted: raise TargetStableError('Cannot collect dmesg on non-rooted target') - self.output_path = None + self.output_path: Optional[str] = None if level not in self.LOG_LEVELS: raise ValueError('level needs to be one of: {}'.format( @@ -207,42 +217,48 @@ def __init__(self, target, level=LOG_LEVELS[-1], facility='kern', empty_buffer=F # e.g. busybox's dmesg or the one shipped on some Android versions # (toybox). 
Note: BusyBox dmesg does not support -h, but will still # print the help with an exit code of 1 - help_ = self.target.execute('dmesg -h', check_exit_code=False) - self.basic_dmesg = not all( + help_: str = self.target.execute('dmesg -h', check_exit_code=False) + self.basic_dmesg: bool = not all( opt in help_ for opt in ('--facility', '--force-prefix', '--decode', '--level') ) self.facility = facility try: - needs_root = target.read_sysctl('kernel.dmesg_restrict') + needs_root: bool = target.read_sysctl('kernel.dmesg_restrict') except ValueError: needs_root = True else: needs_root = bool(int(needs_root)) self.needs_root = needs_root - self._begin_timestamp = None - self.empty_buffer = empty_buffer - self._dmesg_out = None - self._parse_error = parse_error + self._begin_timestamp: Optional[timedelta] = None + self.empty_buffer: bool = empty_buffer + self._dmesg_out: Optional[str] = None + self._parse_error: Optional[str] = parse_error @property - def dmesg_out(self): - out = self._dmesg_out + def dmesg_out(self) -> Optional[str]: + """ + Get the dmesg output + """ + out: Optional[str] = self._dmesg_out if out is None: return None else: try: - entry = self.entries[0] + entry: KernelLogEntry = self.entries[0] except IndexError: return '' else: - i = entry.line_nr + i: int = entry.line_nr return '\n'.join(out.splitlines()[i:]) @property - def entries(self): + def entries(self) -> List[KernelLogEntry]: + """ + Get the entries as a list of class:KernelLogEntry + """ return self._get_entries( self._dmesg_out, self._begin_timestamp, @@ -250,14 +266,15 @@ def entries(self): ) @memoized - def _get_entries(self, dmesg_out, timestamp, error): - entries = KernelLogEntry.from_dmesg_output(dmesg_out, error=error) - entries = list(entries) + def _get_entries(self, dmesg_out: Optional[str], timestamp: Optional[timedelta], + error: Optional[str]) -> List[KernelLogEntry]: + entry_ = KernelLogEntry.from_dmesg_output(dmesg_out, error=error) + entries = list(entry_) if timestamp is None: 
return entries else: try: - first = entries[0] + first: KernelLogEntry = entries[0] except IndexError: pass else: @@ -273,14 +290,17 @@ def _get_entries(self, dmesg_out, timestamp, error): if entry.timestamp > timestamp ] - def _get_output(self): - levels_list = list(takewhile( + def _get_output(self) -> None: + """ + Get the dmesg collector output into _dmesg_out local variable + """ + levels_list: List[str] = list(takewhile( lambda level: level != self.level, self.LOG_LEVELS )) levels_list.append(self.level) if self.basic_dmesg: - cmd = 'dmesg -r' + cmd: str = 'dmesg -r' else: cmd = 'dmesg --facility={facility} --force-prefix --decode --level={levels}'.format( levels=','.join(levels_list), @@ -289,10 +309,17 @@ def _get_output(self): self._dmesg_out = self.target.execute(cmd, as_root=self.needs_root) - def reset(self): + def reset(self) -> None: + """ + Reset the collector's internal state (e.g., cached dmesg output). + """ self._dmesg_out = None - def start(self): + def start(self) -> None: + """ + Start collecting dmesg logs. If ``empty_buffer`` is true, clear them first. + :raises TargetStableError: If the target is not rooted. + """ # If the buffer is emptied on start(), it does not matter as we will # not end up with entries dating from before start() if self.empty_buffer: @@ -307,13 +334,23 @@ def start(self): else: self._begin_timestamp = entry.timestamp - def stop(self): + def stop(self) -> None: + """ + Stop collecting logs and retrieve the latest dmesg output. + """ self._get_output() - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: self.output_path = output_path - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Write the dmesg output to :attr:`output_path` and return a :class:`CollectorOutput`. + + :raises RuntimeError: If :attr:`output_path` is unset. + :return: A collector output referencing the saved dmesg file. 
+ :rtype: CollectorOutput + """ if self.output_path is None: raise RuntimeError("Output path was not set.") with open(self.output_path, 'wt') as f: diff --git a/devlib/collector/ftrace.py b/devlib/collector/ftrace.py index 0aea8eeac..822e341fb 100644 --- a/devlib/collector/ftrace.py +++ b/devlib/collector/ftrace.py @@ -1,4 +1,4 @@ -# Copyright 2015-2018 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,12 +30,19 @@ from devlib.utils.misc import check_output, which, memoized from devlib.utils.asyn import asyncf - -TRACE_MARKER_START = 'TRACE_MARKER_START' -TRACE_MARKER_STOP = 'TRACE_MARKER_STOP' -OUTPUT_TRACE_FILE = 'trace.dat' -OUTPUT_PROFILE_FILE = 'trace_stat.dat' -DEFAULT_EVENTS = [ +from devlib.module.cpufreq import CpufreqModule +from devlib.module.cpuidle import Cpuidle +from typing import (cast, List, Pattern, TYPE_CHECKING, Optional, + Dict, Union, Match) +from devlib.utils.annotation_helpers import BackgroundCommand +if TYPE_CHECKING: + from devlib.target import Target + +TRACE_MARKER_START: str = 'TRACE_MARKER_START' +TRACE_MARKER_STOP: str = 'TRACE_MARKER_STOP' +OUTPUT_TRACE_FILE: str = 'trace.dat' +OUTPUT_PROFILE_FILE: str = 'trace_stat.dat' +DEFAULT_EVENTS: List[str] = [ 'cpu_frequency', 'cpu_idle', 'sched_migrate_task', @@ -46,33 +53,54 @@ 'sched_wakeup', 'sched_wakeup_new', ] -TIMEOUT = 180 +TIMEOUT: int = 180 # Regexps for parsing of function profiling data -CPU_RE = re.compile(r' Function \(CPU([0-9]+)\)') -STATS_RE = re.compile(r'([^ ]*) +([0-9]+) +([0-9.]+) us +([0-9.]+) us +([0-9.]+) us') +CPU_RE: Pattern[str] = re.compile(r' Function \(CPU([0-9]+)\)') +STATS_RE: Pattern[str] = re.compile(r'([^ ]*) +([0-9]+) +([0-9.]+) us +([0-9.]+) us +([0-9.]+) us') -class FtraceCollector(CollectorBase): +class FtraceCollector(CollectorBase): + """ + Collector using ftrace to trace kernel events and functions. 
+ + :param target: The devlib Target (must be rooted). + :param events: A list of events to trace (defaults to `DEFAULT_EVENTS`). + :param functions: A list of functions to trace, if function tracing is used. + :param tracer: The tracer to use (e.g., 'function_graph'), or ``None``. + :param trace_children_functions: If ``True``, trace child functions as well. + :param buffer_size: The size of the trace buffer in KB. + :param top_buffer_size: The top-level buffer size in KB, if different. + :param buffer_size_step: The step size for increasing the buffer. + :param tracing_path: The path to the tracefs mount point, if not auto-detected. + :param automark: If ``True``, automatically mark start and stop in the trace. + :param autoreport: If ``True``, generate a textual trace report automatically. + :param autoview: If ``True``, open KernelShark for a graphical view of the trace. + :param no_install: If ``True``, assume trace-cmd is already installed on target. + :param strict: If ``True``, raise errors if requested events/functions are not available. + :param report_on_target: If ``True``, generate the trace report on the target side. + :param trace_clock: The clock source for the trace. + :param saved_cmdlines_nr: The number of cmdlines to save in the trace buffer. 
+ """ # pylint: disable=too-many-locals,too-many-branches,too-many-statements - def __init__(self, target, - events=None, - functions=None, - tracer=None, - trace_children_functions=False, - buffer_size=None, - top_buffer_size=None, - buffer_size_step=1000, - tracing_path=None, - automark=True, - autoreport=True, - autoview=False, - no_install=False, - strict=False, - report_on_target=False, - trace_clock='local', - saved_cmdlines_nr=4096, - mode='write-to-memory', + def __init__(self, target: 'Target', + events: Optional[List[str]] = None, + functions: Optional[List[str]] = None, + tracer: Optional[str] = None, + trace_children_functions: bool = False, + buffer_size: Optional[int] = None, + top_buffer_size: Optional[int] = None, + buffer_size_step: int = 1000, + tracing_path: Optional[str] = None, + automark: bool = True, + autoreport: bool = True, + autoview: bool = False, + no_install: bool = False, + strict: bool = False, + report_on_target: bool = False, + trace_clock: str = 'local', + saved_cmdlines_nr: int = 4096, + mode: str = 'write-to-memory', ): super(FtraceCollector, self).__init__(target) self.events = events if events is not None else DEFAULT_EVENTS @@ -81,40 +109,40 @@ def __init__(self, target, self.trace_children_functions = trace_children_functions self.buffer_size = buffer_size self.top_buffer_size = top_buffer_size - self.tracing_path = self._resolve_tracing_path(target, tracing_path) + self.tracing_path: str = self._resolve_tracing_path(target, tracing_path) self.automark = automark self.autoreport = autoreport self.autoview = autoview self.strict = strict self.report_on_target = report_on_target - self.target_output_file = target.path.join(self.target.working_directory, OUTPUT_TRACE_FILE) - text_file_name = target.path.splitext(OUTPUT_TRACE_FILE)[0] + '.txt' - self.target_text_file = target.path.join(self.target.working_directory, text_file_name) - self.output_path = None - self.target_binary = None - self.host_binary = None - self.start_time 
= None - self.stop_time = None - self.function_string = None + self.target_output_file: str = target.path.join(self.target.working_directory, OUTPUT_TRACE_FILE) if target.path else '' + text_file_name: str = target.path.splitext(OUTPUT_TRACE_FILE)[0] + '.txt' if target.path else '' + self.target_text_file: str = target.path.join(self.target.working_directory, text_file_name) if target.path else '' + self.output_path: Optional[str] = None + self.target_binary: Optional[str] = None + self.host_binary: Optional[str] = None + self.start_time: Optional[float] = None + self.stop_time: Optional[float] = None + self.function_string: Optional[str] = None self.trace_clock = trace_clock self.saved_cmdlines_nr = saved_cmdlines_nr - self._reset_needed = True + self._reset_needed: bool = True self.mode = mode - self._bg_cmd = None + self._bg_cmd: Optional[BackgroundCommand] = None # pylint: disable=bad-whitespace # Setup tracing paths - self.available_events_file = self.target.path.join(self.tracing_path, 'available_events') - self.available_functions_file = self.target.path.join(self.tracing_path, 'available_filter_functions') - self.current_tracer_file = self.target.path.join(self.tracing_path, 'current_tracer') - self.function_profile_file = self.target.path.join(self.tracing_path, 'function_profile_enabled') - self.marker_file = self.target.path.join(self.tracing_path, 'trace_marker') - self.ftrace_filter_file = self.target.path.join(self.tracing_path, 'set_ftrace_filter') - self.available_tracers_file = self.target.path.join(self.tracing_path, 'available_tracers') - self.kprobe_events_file = self.target.path.join(self.tracing_path, 'kprobe_events') + self.available_events_file: str = self.target.path.join(self.tracing_path, 'available_events') if self.target.path else '' + self.available_functions_file: str = self.target.path.join(self.tracing_path, 'available_filter_functions') if self.target.path else '' + self.current_tracer_file: str = 
self.target.path.join(self.tracing_path, 'current_tracer') if self.target.path else ''
+        self.function_profile_file: str = self.target.path.join(self.tracing_path, 'function_profile_enabled') if self.target.path else ''
+        self.marker_file: str = self.target.path.join(self.tracing_path, 'trace_marker') if self.target.path else ''
+        self.ftrace_filter_file: str = self.target.path.join(self.tracing_path, 'set_ftrace_filter') if self.target.path else ''
+        self.available_tracers_file: str = self.target.path.join(self.tracing_path, 'available_tracers') if self.target.path else ''
+        self.kprobe_events_file: str = self.target.path.join(self.tracing_path, 'kprobe_events') if self.target.path else ''

         self.host_binary = which('trace-cmd')
-        self.kernelshark = which('kernelshark')
+        self.kernelshark: Optional[str] = which('kernelshark')

         if not self.target.is_rooted:
             raise TargetStableError('trace-cmd instrument cannot be used on an unrooted device.')
@@ -123,7 +151,7 @@ def __init__(self, target,
         if self.autoview and self.kernelshark is None:
             raise HostError('kernelshark binary must be installed on the host if autoview=True.')
         if not no_install:
-            host_file = os.path.join(PACKAGE_BIN_DIRECTORY, self.target.abi, 'trace-cmd')
+            host_file = os.path.join(PACKAGE_BIN_DIRECTORY, self.target.abi or '', 'trace-cmd')
             self.target_binary = self.target.install(host_file)
         else:
             if not self.target.is_installed('trace-cmd'):
@@ -131,26 +159,50 @@ def __init__(self, target,
                 self.target_binary = 'trace-cmd'

         # Validate required events to be traced
-        def event_to_regex(event):
+        def event_to_regex(event: str) -> Pattern[str]:
+            """
+            Converts a wildcard-style event name to a compiled regular expression.
+
+            This allows events with '*' wildcards to be matched against actual trace events.
+            For example, 'sched*' becomes the pattern '.*sched.*' (a leading '*' is prepended
+            first), which can match 'sched_switch', 'sched_wakeup', etc.
+
+            Parameters:
+                event (str): The event name, potentially containing wildcards.
+ + Returns: + Pattern[str]: A compiled regular expression that can be used to match trace event names. + """ if not event.startswith('*'): event = '*' + event return re.compile(event.replace('*', '.*')) - def event_is_in_list(event, events): + def event_is_in_list(event: str, events: List[str]) -> bool: + """ + Determines whether a given event matches any of the patterns in a list of trace events. + + Each pattern in the list can contain wildcards and is matched using regex. + + Parameters: + event (str): The event name to test. + events (List[str]): A list of event patterns (may include wildcards). + + Returns: + bool: True if the event matches at least one of the patterns; False otherwise. + """ return any( event_to_regex(event).match(_event) for _event in events ) - available_events = self.available_events - unavailable_events = [ + available_events: List[str] = self.available_events + unavailable_events: List[str] = [ event for event in self.events if not event_is_in_list(event, available_events) ] if unavailable_events: - message = 'Events not available for tracing: {}'.format( + message: str = 'Events not available for tracing: {}'.format( ', '.join(unavailable_events) ) if self.strict: @@ -158,7 +210,7 @@ def event_is_in_list(event, events): else: self.target.logger.warning(message) - selected_events = sorted(set(self.events) - set(unavailable_events)) + selected_events: List[str] = sorted(set(self.events) - set(unavailable_events)) if self.tracer and self.tracer not in self.available_tracers: raise TargetStableError('Unsupported tracer "{}". 
Available tracers: {}'.format( @@ -167,7 +219,7 @@ def event_is_in_list(event, events): # Check for function tracing support if self.functions: # Validate required functions to be traced - selected_functions = [] + selected_functions: List[str] = [] for function in self.functions: if function not in self.available_functions: message = 'Function [{}] not available for tracing/profiling'.format(function) @@ -180,7 +232,7 @@ def event_is_in_list(event, events): # Function profiling if self.tracer is None: if not self.target.file_exists(self.function_profile_file): - raise TargetStableError('Function profiling not supported. '\ + raise TargetStableError('Function profiling not supported. ' 'A kernel build with CONFIG_FUNCTION_PROFILER enable is required') self.function_string = _build_trace_functions(selected_functions) # If function profiling is enabled we always need at least one event. @@ -205,14 +257,20 @@ def event_string(self): return _build_trace_events(self._selected_events) @classmethod - def _resolve_tracing_path(cls, target, path): + def _resolve_tracing_path(cls, target: 'Target', path: Optional[str]) -> str: + """ + Find path for tracefs + """ if path is None: return cls.find_tracing_path(target) else: return path @classmethod - def find_tracing_path(cls, target): + def find_tracing_path(cls, target: 'Target') -> str: + """ + get tracefs path from mount point + """ fs_list = [ fs.mount_point for fs in target.list_file_systems() @@ -226,14 +284,14 @@ def find_tracing_path(cls, target): @property @memoized - def available_tracers(self): + def available_tracers(self) -> List[str]: """ List of ftrace tracers supported by the target's kernel. """ return self.target.read_value(self.available_tracers_file).split(' ') @property - def available_events(self): + def available_events(self) -> List[str]: """ List of ftrace events supported by the target's kernel. 
""" @@ -241,16 +299,16 @@ def available_events(self): @property @memoized - def available_functions(self): + def available_functions(self) -> List[str]: """ List of functions whose tracing/profiling is supported by the target's kernel. """ return self.target.read_value(self.available_functions_file).splitlines() - def reset(self): + def reset(self) -> None: # Save kprobe events try: - kprobe_events = self.target.read_value(self.kprobe_events_file) + kprobe_events: Optional[str] = self.target.read_value(self.kprobe_events_file) except TargetStableError: kprobe_events = None @@ -261,10 +319,10 @@ def reset(self): # parameter, but unfortunately some events still end up there (e.g. # print event). So we still need to set that size, otherwise the buffer # might be too small and some event lost. - top_buffer_size = self.top_buffer_size if self.top_buffer_size else self.buffer_size + top_buffer_size: Optional[int] = self.top_buffer_size if self.top_buffer_size else self.buffer_size if top_buffer_size: self.target.write_value( - self.target.path.join(self.tracing_path, 'buffer_size_kb'), + self.target.path.join(self.tracing_path, 'buffer_size_kb') if self.target.path else '', top_buffer_size, verify=False ) @@ -285,7 +343,7 @@ def _trace_frequencies(self): except TargetStableError as e: self.logger.error(f'Could not trace CPUFreq frequencies as the cpufreq module cannot be loaded: {e}') else: - mod.trace_frequencies() + cast(CpufreqModule, mod).trace_frequencies() def _trace_idle(self): if 'cpu_idle' in self._selected_events: @@ -295,20 +353,25 @@ def _trace_idle(self): except TargetStableError as e: self.logger.error(f'Could not trace CPUIdle states as the cpuidle module cannot be loaded: {e}') else: - mod.perturb_cpus() + cast(Cpuidle, mod).perturb_cpus() @asyncf - async def start(self): + async def start(self) -> None: + """ + Start capturing ftrace events according to the selected events/functions. 
+ + :raises TargetStableError: If the target is unrooted or tracing setup fails. + """ self.start_time = time.time() if self._reset_needed: self.reset() if self.tracer is not None and 'function' in self.tracer: - tracecmd_functions = self.function_string + tracecmd_functions: Optional[str] = self.function_string else: tracecmd_functions = '' - tracer_string = '-p {}'.format(self.tracer) if self.tracer else '' + tracer_string: str = '-p {}'.format(self.tracer) if self.tracer else '' # Ensure kallsyms contains addresses if possible, so that function the # collected trace contains enough data for pretty printing @@ -352,18 +415,20 @@ async def start(self): if self.functions and self.tracer is None: target = self.target await target.async_manager.concurrently( - execute.asyn('echo nop > {}'.format(self.current_tracer_file), - as_root=True), - execute.asyn('echo 0 > {}'.format(self.function_profile_file), + target.execute.asyn('echo nop > {}'.format(self.current_tracer_file), as_root=True), - execute.asyn('echo {} > {}'.format(self.function_string, self.ftrace_filter_file), + target.execute.asyn('echo 0 > {}'.format(self.function_profile_file), + as_root=True), # type: ignore + target.execute.asyn('echo {} > {}'.format(self.function_string, self.ftrace_filter_file), as_root=True), - execute.asyn('echo 1 > {}'.format(self.function_profile_file), + target.execute.asyn('echo 1 > {}'.format(self.function_profile_file), as_root=True), ) - - def stop(self): + def stop(self) -> None: + """ + Stop capturing ftrace events. 
+ """ # Disable kernel function profiling if self.functions and self.tracer is None: self.target.execute('echo 0 > {}'.format(self.function_profile_file), @@ -388,16 +453,23 @@ def stop(self): self._reset_needed = True - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: if os.path.isdir(output_path): output_path = os.path.join(output_path, os.path.basename(self.target_output_file)) self.output_path = output_path - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Pull the captured trace data from the target, optionally generate a report, + and return a :class:`CollectorOutput`. + + :raises RuntimeError: If :attr:`output_path` is unset. + :return: A collector output referencing ftrace data. + """ if self.output_path is None: raise RuntimeError("Output path was not set.") - busybox = quote(self.target.busybox) + busybox = quote(self.target.busybox or '') mode = self.mode if mode == 'write-to-disk': @@ -412,7 +484,7 @@ def get_data(self): # The size of trace.dat will depend on how long trace-cmd was running. # Therefore timout for the pull command must also be adjusted # accordingly. 
- pull_timeout = 10 * (self.stop_time - self.start_time) + pull_timeout: float = 10 * (cast(float, self.stop_time) - cast(float, self.start_time)) self.target.pull(self.target_output_file, self.output_path, timeout=pull_timeout) output = CollectorOutput() if not os.path.isfile(self.output_path): @@ -420,7 +492,7 @@ def get_data(self): else: output.append(CollectorOutputEntry(self.output_path, 'file')) if self.autoreport: - textfile = os.path.splitext(self.output_path)[0] + '.txt' + textfile: str = os.path.splitext(self.output_path)[0] + '.txt' if self.report_on_target: self.generate_report_on_target() self.target.pull(self.target_text_file, @@ -432,20 +504,26 @@ def get_data(self): self.view(self.output_path) return output - def get_stats(self, outfile): + def get_stats(self, outfile: str) -> Optional[Dict[int, + Dict[str, + Dict[str, Union[int, float]]]]]: + """ + get the processing statistics for the cpu + :param outfile: path to the output file + """ if not (self.functions and self.tracer is None): - return + return None if os.path.isdir(outfile): outfile = os.path.join(outfile, OUTPUT_PROFILE_FILE) # pylint: disable=protected-access - output = self.target._execute_util('ftrace_get_function_stats', - as_root=True) + output: str = self.target._execute_util('ftrace_get_function_stats', + as_root=True) - function_stats = {} + function_stats: Dict[int, Dict[str, Dict[str, Union[int, float]]]] = {} for line in output.splitlines(): # Match a new CPU dataset - match = CPU_RE.search(line) + match: Optional[Match[str]] = CPU_RE.search(line) if match: cpu_id = int(match.group(1)) function_stats[cpu_id] = {} @@ -456,13 +534,13 @@ def get_stats(self, outfile): if match: fname = match.group(1) function_stats[cpu_id][fname] = { - 'hits' : int(match.group(2)), - 'time' : float(match.group(3)), - 'avg' : float(match.group(4)), - 's_2' : float(match.group(5)), - } + 'hits': int(match.group(2)), + 'time': float(match.group(3)), + 'avg': float(match.group(4)), + 's_2': 
float(match.group(5)), + } self.logger.debug(" %s: %s", - fname, function_stats[cpu_id][fname]) + fname, function_stats[cpu_id][fname]) self.logger.debug("FTrace stats output [%s]...", outfile) with open(outfile, 'w') as fh: @@ -471,15 +549,23 @@ def get_stats(self, outfile): return function_stats - def report(self, binfile, destfile): + def report(self, binfile: str, destfile: str) -> None: + """ + Generate a textual report from a captured trace.dat file on the host. + + :param binfile: The path to the binary trace file. + :param destfile: The path to write the report. + :raises TargetStableError: If trace-cmd returns a non-zero exit code. + :raises HostError: If trace-cmd is not found on the host. + """ # To get the output of trace.dat, trace-cmd must be installed # This is done host-side because the generated file is very large try: - command = '{} report {} > {}'.format(self.host_binary, binfile, destfile) + command: str = '{} report {} > {}'.format(self.host_binary, binfile, destfile) self.logger.debug(command) process = subprocess.Popen(command, stderr=subprocess.PIPE, shell=True) - _, error = process.communicate() - error = error.decode(sys.stdout.encoding or 'utf-8', 'replace') + _, error_b = process.communicate() + error = error_b.decode(sys.stdout.encoding or 'utf-8', 'replace') if process.returncode: raise TargetStableError('trace-cmd returned non-zero exit code {}'.format(process.returncode)) if error: @@ -500,34 +586,52 @@ def report(self, binfile, destfile): except OSError: raise HostError('Could not find trace-cmd. 
Please make sure it is installed and is in PATH.') - def generate_report_on_target(self): - command = '{} report {} > {}'.format(self.target_binary, - self.target_output_file, - self.target_text_file) + def generate_report_on_target(self) -> None: + """ + generate report on target + """ + command: str = '{} report {} > {}'.format(self.target_binary, + self.target_output_file, + self.target_text_file) self.target.execute(command, timeout=TIMEOUT) - def view(self, binfile): + def view(self, binfile: str) -> None: + """ + KernelShark is a graphical front-end tool for visualizing trace data collected by trace-cmd. + It allows users to view and analyze kernel tracing data in a more intuitive and interactive way. + """ check_output('{} {}'.format(self.kernelshark, binfile), shell=True) - def teardown(self): - self.target.remove(self.target.path.join(self.target.working_directory, OUTPUT_TRACE_FILE)) + def teardown(self) -> None: + """ + Remove the trace.dat file from the target, cleaning up after data collection. + """ + self.target.remove(self.target.path.join(self.target.working_directory, OUTPUT_TRACE_FILE) if self.target.path else '') - def mark_start(self): + def mark_start(self) -> None: + """ + Write a start marker into the ftrace marker file. + """ self.target.write_value(self.marker_file, TRACE_MARKER_START, verify=False) - def mark_stop(self): + def mark_stop(self) -> None: + """ + Write a stop marker into the ftrace marker file. 
+ """ self.target.write_value(self.marker_file, TRACE_MARKER_STOP, verify=False) -def _build_trace_events(events): - event_string = ' '.join(['-e {}'.format(e) for e in events]) +def _build_trace_events(events: List[str]) -> str: + event_string: str = ' '.join(['-e {}'.format(e) for e in events]) return event_string -def _build_trace_functions(functions): - function_string = " ".join(functions) + +def _build_trace_functions(functions: List[str]) -> str: + function_string: str = " ".join(functions) return function_string -def _build_graph_functions(functions, trace_children_functions): + +def _build_graph_functions(functions: List[str], trace_children_functions: bool) -> str: opt = 'g' if trace_children_functions else 'l' return ' '.join( '-{} {}'.format(opt, quote(f)) diff --git a/devlib/collector/logcat.py b/devlib/collector/logcat.py index 770c9054b..614511cf6 100644 --- a/devlib/collector/logcat.py +++ b/devlib/collector/logcat.py @@ -1,4 +1,4 @@ -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,19 +18,33 @@ from devlib.collector import (CollectorBase, CollectorOutput, CollectorOutputEntry) from devlib.utils.android import LogcatMonitor +from typing import (cast, TYPE_CHECKING, List, Optional, + Union) +from io import TextIOWrapper +from tempfile import _TemporaryFileWrapper +if TYPE_CHECKING: + from devlib.target import AndroidTarget, Target + class LogcatCollector(CollectorBase): + """ + A collector that retrieves logs via `adb logcat` from an Android target. - def __init__(self, target, regexps=None, logcat_format=None): + :param target: The devlib Target (must be Android). + :param regexps: A list of regular expressions to filter log lines (optional). + :param logcat_format: The desired logcat output format (optional). 
+ """ + def __init__(self, target: 'Target', regexps: Optional[List[str]] = None, + logcat_format: Optional[str] = None): super(LogcatCollector, self).__init__(target) self.regexps = regexps self.logcat_format = logcat_format - self.output_path = None - self._collecting = False - self._prev_log = None - self._monitor = None + self.output_path: Optional[str] = None + self._collecting: bool = False + self._prev_log: Optional[Union[TextIOWrapper, _TemporaryFileWrapper[str]]] = None + self._monitor: Optional[LogcatMonitor] = None - def reset(self): + def reset(self) -> None: """ Clear Collector data but do not interrupt collection """ @@ -40,39 +54,45 @@ def reset(self): if self._collecting: self._monitor.clear_log() elif self._prev_log: - os.remove(self._prev_log) + os.remove(cast(str, self._prev_log)) self._prev_log = None - def start(self): + def start(self) -> None: """ - Start collecting logcat lines + Start capturing logcat output. Raises RuntimeError if no output path is set. """ if self.output_path is None: raise RuntimeError("Output path was not set.") - self._monitor = LogcatMonitor(self.target, self.regexps, logcat_format=self.logcat_format) + self._monitor = LogcatMonitor(cast('AndroidTarget', self.target), self.regexps, logcat_format=self.logcat_format) if self._prev_log: # Append new data collection to previous collection - self._monitor.start(self._prev_log) + self._monitor.start(cast(str, self._prev_log)) else: self._monitor.start(self.output_path) self._collecting = True - def stop(self): + def stop(self) -> None: """ Stop collecting logcat lines """ if not self._collecting: raise RuntimeError('Logcat monitor not running, nothing to stop') - - self._monitor.stop() + if self._monitor: + self._monitor.stop() self._collecting = False - self._prev_log = self._monitor.logfile + self._prev_log = self._monitor.logfile if self._monitor else None - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: self.output_path = 
output_path - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Return a :class:`CollectorOutput` for the captured logcat data. + + :raises RuntimeError: If :attr:`output_path` is unset. + :return: A collector output referencing the logcat file. + """ if self.output_path is None: raise RuntimeError("No data collected.") return CollectorOutput([CollectorOutputEntry(self.output_path, 'file')]) diff --git a/devlib/collector/perf.py b/devlib/collector/perf.py index a1389967a..ed6129915 100644 --- a/devlib/collector/perf.py +++ b/devlib/collector/perf.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,25 +16,30 @@ import os import re import time -from past.builtins import basestring, zip +from past.builtins import zip from devlib.host import PACKAGE_BIN_DIRECTORY from devlib.collector import (CollectorBase, CollectorOutput, CollectorOutputEntry) from devlib.utils.misc import ensure_file_directory_exists as _f +from typing import (cast, List, Dict, TYPE_CHECKING, Optional, + Union, Pattern) +from signal import Signals +if TYPE_CHECKING: + from devlib.target import Target -PERF_STAT_COMMAND_TEMPLATE = '{binary} {command} {options} {events} {sleep_cmd} > {outfile} 2>&1 ' -PERF_REPORT_COMMAND_TEMPLATE= '{binary} report {options} -i {datafile} > {outfile} 2>&1 ' -PERF_REPORT_SAMPLE_COMMAND_TEMPLATE= '{binary} report-sample {options} -i {datafile} > {outfile} ' -PERF_RECORD_COMMAND_TEMPLATE= '{binary} record {options} {events} -o {outfile}' +PERF_STAT_COMMAND_TEMPLATE: str = '{binary} {command} {options} {events} {sleep_cmd} > {outfile} 2>&1 ' +PERF_REPORT_COMMAND_TEMPLATE: str = '{binary} report {options} -i {datafile} > {outfile} 2>&1 ' +PERF_REPORT_SAMPLE_COMMAND_TEMPLATE: str = '{binary} report-sample {options} -i {datafile} > {outfile} ' +PERF_RECORD_COMMAND_TEMPLATE: str = 
'{binary} record {options} {events} -o {outfile}' -PERF_DEFAULT_EVENTS = [ +PERF_DEFAULT_EVENTS: List[str] = [ 'cpu-migrations', 'context-switches', ] -SIMPLEPERF_DEFAULT_EVENTS = [ +SIMPLEPERF_DEFAULT_EVENTS: List[str] = [ 'raw-cpu-cycles', 'raw-l1-dcache', 'raw-l1-dcache-refill', @@ -42,7 +47,8 @@ 'raw-instruction-retired', ] -DEFAULT_EVENTS = {'perf':PERF_DEFAULT_EVENTS, 'simpleperf':SIMPLEPERF_DEFAULT_EVENTS} +DEFAULT_EVENTS: Dict[str, List[str]] = {'perf': PERF_DEFAULT_EVENTS, 'simpleperf': SIMPLEPERF_DEFAULT_EVENTS} + class PerfCollector(CollectorBase): """ @@ -82,43 +88,55 @@ class PerfCollector(CollectorBase): Options can be obtained by running the following in the command line :: man perf-stat + + :param target: The devlib Target (rooted if on Android). + :param perf_type: Either 'perf' or 'simpleperf'. + :param command: The perf command to run (e.g. 'stat' or 'record'). + :param events: A list of events to collect. Defaults to built-in sets. + :param optionstring: Extra CLI options (a string or list of strings). + :param report_options: Additional options for ``perf report``. + :param run_report_sample: If True, run the ``report-sample`` subcommand. + :param report_sample_options: Additional options for ``report-sample``. + :param labels: Unique labels for each command or option set. + :param force_install: If True, reinstall perf even if it's already on the target. + :param validate_events: If True, verify that requested events are available. 
""" def __init__(self, - target, - perf_type='perf', - command='stat', - events=None, - optionstring=None, - report_options=None, - run_report_sample=False, - report_sample_options=None, - labels=None, - force_install=False, - validate_events=True): + target: 'Target', + perf_type: str = 'perf', + command: str = 'stat', + events: Optional[List[str]] = None, + optionstring: Optional[Union[str, List[str]]] = None, + report_options: Optional[str] = None, + run_report_sample: bool = False, + report_sample_options: Optional[str] = None, + labels: Optional[List[str]] = None, + force_install: bool = False, + validate_events: bool = True): super(PerfCollector, self).__init__(target) self.force_install = force_install self.labels = labels self.report_options = report_options self.run_report_sample = run_report_sample self.report_sample_options = report_sample_options - self.output_path = None + self.output_path: Optional[str] = None self.validate_events = validate_events # Validate parameters if isinstance(optionstring, list): - self.optionstrings = optionstring + self.optionstrings: List[str] = optionstring else: - self.optionstrings = [optionstring] + self.optionstrings = [optionstring] if optionstring else [] if perf_type in ['perf', 'simpleperf']: - self.perf_type = perf_type + self.perf_type: str = perf_type else: raise ValueError('Invalid perf type: {}, must be perf or simpleperf'.format(perf_type)) if not events: - self.events = DEFAULT_EVENTS[self.perf_type] + self.events: List[str] = DEFAULT_EVENTS[self.perf_type] else: self.events = events - if isinstance(self.events, basestring): + if isinstance(self.events, str): self.events = [self.events] if not self.labels: self.labels = ['perf_{}'.format(i) for i in range(len(self.optionstrings))] @@ -133,51 +151,71 @@ def __init__(self, if report_sample_options and (command != 'record'): raise ValueError('report_sample_options specified, but command is not record') - self.binary = self.target.get_installed(self.perf_type) + 
self.binary: str = self.target.get_installed(self.perf_type) if self.force_install or not self.binary: self.binary = self._deploy_perf() if self.validate_events: self._validate_events(self.events) - self.commands = self._build_commands() + self.commands: List[str] = self._build_commands() - def reset(self): + def reset(self) -> None: self.target.killall(self.perf_type, as_root=self.target.is_rooted) self.target.remove(self.target.get_workpath('TemporaryFile*')) - for label in self.labels: - filepath = self._get_target_file(label, 'data') - self.target.remove(filepath) - filepath = self._get_target_file(label, 'rpt') - self.target.remove(filepath) - filepath = self._get_target_file(label, 'rptsamples') - self.target.remove(filepath) - - def start(self): + if self.labels: + for label in self.labels: + filepath = self._get_target_file(label, 'data') + self.target.remove(filepath) + filepath = self._get_target_file(label, 'rpt') + self.target.remove(filepath) + filepath = self._get_target_file(label, 'rptsamples') + self.target.remove(filepath) + + def start(self) -> None: + """ + Start the perf command(s) in the background on the target. + """ for command in self.commands: self.target.background(command, as_root=self.target.is_rooted) - def stop(self): - self.target.killall(self.perf_type, signal='SIGINT', + def stop(self) -> None: + """ + Send SIGINT to terminate the perf tool, finalizing any data files. + """ + self.target.killall(self.perf_type, signal=cast(Signals, 'SIGINT'), as_root=self.target.is_rooted) if self.perf_type == "perf" and self.command == "stat": # perf doesn't transmit the signal to its sleep call so handled here: self.target.killall('sleep', as_root=self.target.is_rooted) # NB: we hope that no other "important" sleep is on-going - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: + """ + Define where perf data or reports will be stored on the host. 
+ + :param output_path: A directory or file path for storing perf results. + """ self.output_path = output_path - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Pull the perf data from the target to the host and optionally generate + textual reports. + + :raises RuntimeError: If :attr:`output_path` is unset. + :return: A collector output referencing the saved perf files. + """ if self.output_path is None: raise RuntimeError("Output path was not set.") output = CollectorOutput() - + if self.labels is None: + raise RuntimeError("labels not set") for label in self.labels: if self.command == 'record': self._wait_for_data_file_write(label, self.output_path) - path = self._pull_target_file_to_host(label, 'rpt', self.output_path) + path: str = self._pull_target_file_to_host(label, 'rpt', self.output_path) output.append(CollectorOutputEntry(path, 'file')) if self.run_report_sample: report_samples_path = self._pull_target_file_to_host(label, 'rptsamples', self.output_path) @@ -187,16 +225,25 @@ def get_data(self): output.append(CollectorOutputEntry(path, 'file')) return output - def _deploy_perf(self): - host_executable = os.path.join(PACKAGE_BIN_DIRECTORY, - self.target.abi, self.perf_type) + def _deploy_perf(self) -> str: + """ + install perf on target + """ + host_executable: str = os.path.join(PACKAGE_BIN_DIRECTORY, + cast(str, self.target.abi), self.perf_type) return self.target.install(host_executable) - def _get_target_file(self, label, extension): + def _get_target_file(self, label: str, extension: str) -> Optional[str]: + """ + get file path on target + """ return self.target.get_workpath('{}.{}'.format(label, extension)) - def _build_commands(self): - commands = [] + def _build_commands(self) -> List[str]: + """ + build perf commands + """ + commands: List[str] = [] for opts, label in zip(self.optionstrings, self.labels): if self.command == 'stat': commands.append(self._build_perf_stat_command(opts, self.events, label)) @@ -204,50 +251,78 @@ 
def _build_commands(self):
 commands.append(self._build_perf_record_command(opts, label))
 return commands

- def _build_perf_stat_command(self, options, events, label):
- event_string = ' '.join(['-e {}'.format(e) for e in events])
- sleep_cmd = 'sleep 1000' if self.perf_type == 'perf' else ''
- command = PERF_STAT_COMMAND_TEMPLATE.format(binary = self.binary,
- command = self.command,
- options = options or '',
- events = event_string,
- sleep_cmd = sleep_cmd,
- outfile = self._get_target_file(label, 'out'))
+ def _build_perf_stat_command(self, options: str, events: List[str], label) -> str:
+ """
+ Construct a perf stat command string.
+
+ :param options: Additional perf stat options.
+ :param events: The list of events to measure.
+ :param label: A label to identify this command/run.
+ :return: A command string suitable for running on the target.
+ """
+ event_string: str = ' '.join(['-e {}'.format(e) for e in events])
+ sleep_cmd: str = 'sleep 1000' if self.perf_type == 'perf' else ''
+ command: str = PERF_STAT_COMMAND_TEMPLATE.format(binary=self.binary,
+ command=self.command,
+ options=options or '',
+ events=event_string,
+ sleep_cmd=sleep_cmd,
+ outfile=self._get_target_file(label, 'out'))
 return command

- def _build_perf_report_command(self, report_options, label):
+ def _build_perf_report_command(self, report_options: Optional[str], label: str) -> str:
+ """
+ Construct a perf report command string.
+
+ :param report_options: Additional options for the perf
+ report invocation, or None for none.
+ :param label: A label to identify this command/run.
+ :return: A command string suitable for running on the target.
+ """ command = PERF_REPORT_COMMAND_TEMPLATE.format(binary=self.binary, options=report_options or '', datafile=self._get_target_file(label, 'data'), outfile=self._get_target_file(label, 'rpt')) return command - def _build_perf_report_sample_command(self, label): + def _build_perf_report_sample_command(self, label: str) -> str: + """ + build perf report sample command + """ command = PERF_REPORT_SAMPLE_COMMAND_TEMPLATE.format(binary=self.binary, - options=self.report_sample_options or '', - datafile=self._get_target_file(label, 'data'), - outfile=self._get_target_file(label, 'rptsamples')) + options=self.report_sample_options or '', + datafile=self._get_target_file(label, 'data'), + outfile=self._get_target_file(label, 'rptsamples')) return command - def _build_perf_record_command(self, options, label): - event_string = ' '.join(['-e {}'.format(e) for e in self.events]) - command = PERF_RECORD_COMMAND_TEMPLATE.format(binary=self.binary, - options=options or '', - events=event_string, - outfile=self._get_target_file(label, 'data')) + def _build_perf_record_command(self, options: Optional[str], label: str) -> str: + """ + build perf record command + """ + event_string: str = ' '.join(['-e {}'.format(e) for e in self.events]) + command: str = PERF_RECORD_COMMAND_TEMPLATE.format(binary=self.binary, + options=options or '', + events=event_string, + outfile=self._get_target_file(label, 'data')) return command - def _pull_target_file_to_host(self, label, extension, output_path): - target_file = self._get_target_file(label, extension) - host_relpath = os.path.basename(target_file) - host_file = _f(os.path.join(output_path, host_relpath)) + def _pull_target_file_to_host(self, label: str, extension: str, output_path: str) -> str: + """ + pull a file from target to host + """ + target_file: Optional[str] = self._get_target_file(label, extension) + host_relpath: str = os.path.basename(target_file or '') + host_file: str = _f(os.path.join(output_path, host_relpath)) 
self.target.pull(target_file, host_file) return host_file - def _wait_for_data_file_write(self, label, output_path): - data_file_finished_writing = False - max_tries = 80 - current_tries = 0 + def _wait_for_data_file_write(self, label: str, output_path: str) -> None: + """ + wait for file write operation by perf + """ + data_file_finished_writing: bool = False + max_tries: int = 80 + current_tries: int = 0 while not data_file_finished_writing: files = self.target.execute('cd {} && ls'.format(self.target.get_workpath(''))) # Perf stores data in tempory files whilst writing to data output file. Check if they have been removed. @@ -259,15 +334,18 @@ def _wait_for_data_file_write(self, label, output_path): self.logger.warning('''writing {}.data file took longer than expected, file may not have written correctly'''.format(label)) data_file_finished_writing = True - report_command = self._build_perf_report_command(self.report_options, label) + report_command: str = self._build_perf_report_command(self.report_options, label) self.target.execute(report_command) if self.run_report_sample: report_sample_command = self._build_perf_report_sample_command(label) self.target.execute(report_sample_command) - def _validate_events(self, events): - available_events_string = self.target.execute('{} list | {} cat'.format(self.perf_type, self.target.busybox)) - available_events = available_events_string.splitlines() + def _validate_events(self, events: List[str]) -> None: + """ + validate events against available perf events on target + """ + available_events_string: str = self.target.execute('{} list | {} cat'.format(self.perf_type, self.target.busybox)) + available_events: List[str] = available_events_string.splitlines() for available_event in available_events: if available_event == '': continue @@ -275,7 +353,7 @@ def _validate_events(self, events): available_events.append(available_event.split('OR')[1]) available_events[available_events.index(available_event)] = 
available_event.split()[0].strip() # Raw hex event codes can also be passed in that do not appear on perf/simpleperf list, prefixed with 'r' - raw_event_code_regex = re.compile(r"^r(0x|0X)?[A-Fa-f0-9]+$") + raw_event_code_regex: Pattern[str] = re.compile(r"^r(0x|0X)?[A-Fa-f0-9]+$") for event in events: if event in available_events or re.match(raw_event_code_regex, event): continue diff --git a/devlib/collector/perfetto.py b/devlib/collector/perfetto.py index c5070e03a..082107df4 100644 --- a/devlib/collector/perfetto.py +++ b/devlib/collector/perfetto.py @@ -1,4 +1,4 @@ -# Copyright 2023 ARM Limited +# Copyright 2023-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,6 +21,9 @@ from devlib.collector import (CollectorBase, CollectorOutput, CollectorOutputEntry) from devlib.exception import TargetStableError, HostError +from typing import TYPE_CHECKING, Optional +if TYPE_CHECKING: + from devlib.target import Target, BackgroundCommand OUTPUT_PERFETTO_TRACE = 'devlib-trace.perfetto-trace' @@ -53,29 +56,33 @@ class PerfettoCollector(CollectorBase): For more information consult the official documentation: https://perfetto.dev/docs/ + + :param target: The devlib Target. + :param config: Path to a Perfetto text config (proto) if any. + :param force_tracebox: If True, force usage of tracebox instead of native Perfetto. 
""" - def __init__(self, target, config=None, force_tracebox=False): + def __init__(self, target: 'Target', config: Optional[str] = None, force_tracebox: bool = False): super().__init__(target) - self.bg_cmd = None + self.bg_cmd: Optional['BackgroundCommand'] = None self.config = config - self.target_binary = 'perfetto' - target_output_path = self.target.working_directory + self.target_binary: str = 'perfetto' + target_output_path: Optional[str] = self.target.working_directory - install_tracebox = force_tracebox or (target.os in ['linux', 'android'] and not target.is_running('traced')) + install_tracebox: bool = force_tracebox or (target.os in ['linux', 'android'] and not target.is_running('traced')) # Install Perfetto through tracebox if install_tracebox: self.target_binary = 'tracebox' if not self.target.get_installed(self.target_binary): - host_executable = os.path.join(PACKAGE_BIN_DIRECTORY, - self.target.abi, self.target_binary) + host_executable: str = os.path.join(PACKAGE_BIN_DIRECTORY, + self.target.abi or '', self.target_binary) if not os.path.exists(host_executable): raise HostError("{} not found on the host".format(self.target_binary)) self.target.install(host_executable) # Use Android's built-in Perfetto elif target.os == 'android': - os_version = target.os_version['release'] + os_version: str = target.os_version['release'] if int(os_version) >= 9: # Android requires built-in Perfetto to write to this directory target_output_path = '/data/misc/perfetto-traces' @@ -83,11 +90,16 @@ def __init__(self, target, config=None, force_tracebox=False): if int(os_version) <= 10: target.execute('setprop persist.traced.enable 1') - self.target_output_file = target.path.join(target_output_path, OUTPUT_PERFETTO_TRACE) + self.target_output_file = target.path.join(target_output_path, OUTPUT_PERFETTO_TRACE) if target.path else '' + + def start(self) -> None: + """ + Start Perfetto tracing by feeding the config to the perfetto (or tracebox) binary. 
- def start(self): - cmd = "{} cat {} | {} --txt -c - -o {}".format( - quote(self.target.busybox), quote(self.config), quote(self.target_binary), quote(self.target_output_file) + :raises TargetStableError: If perfetto/tracebox cannot be started on the target. + """ + cmd: str = "{} cat {} | {} --txt -c - -o {}".format( + quote(self.target.busybox or ''), quote(self.config or ''), quote(self.target_binary), quote(self.target_output_file) ) # start tracing if self.bg_cmd is None: @@ -95,17 +107,32 @@ def start(self): else: raise TargetStableError('Perfetto collector is not re-entrant') - def stop(self): - # stop tracing - self.bg_cmd.cancel() - self.bg_cmd = None - - def set_output(self, output_path): + def stop(self) -> None: + """ + Stop Perfetto tracing and finalize the trace file. + """ + if self.bg_cmd: + # stop tracing + self.bg_cmd.cancel() + self.bg_cmd = None + + def set_output(self, output_path: str) -> None: + """ + Specify where the trace file will be pulled on the host. + + :param output_path: The file path or directory on the host. + """ if os.path.isdir(output_path): output_path = os.path.join(output_path, os.path.basename(self.target_output_file)) self.output_path = output_path - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Pull the trace file from the target and return a :class:`CollectorOutput`. + + :raises RuntimeError: If :attr:`output_path` is unset or if no trace file exists. + :return: A collector output referencing the Perfetto trace file. 
+ """ if self.output_path is None: raise RuntimeError("Output path was not set.") if not self.target.file_exists(self.target_output_file): diff --git a/devlib/collector/screencapture.py b/devlib/collector/screencapture.py index 399227fc8..fb07d158f 100644 --- a/devlib/collector/screencapture.py +++ b/devlib/collector/screencapture.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,27 +22,40 @@ from devlib.collector import (CollectorBase, CollectorOutput, CollectorOutputEntry) from devlib.exception import WorkerThreadError +from devlib.utils.misc import get_logger +from typing import TYPE_CHECKING, Optional, cast +if TYPE_CHECKING: + from devlib.target import Target class ScreenCapturePoller(threading.Thread): - - def __init__(self, target, period, timeout=30): + """ + A background thread that periodically captures screenshots from the target. + + :param target: The devlib Target. + :param period: Interval in seconds between captures. If None, the logic may differ. + :param timeout: Maximum time to wait for the poller thread to stop. 
+ """ + def __init__(self, target: 'Target', period: Optional[float], timeout: int = 30): super(ScreenCapturePoller, self).__init__() self.target = target - self.logger = logging.getLogger('screencapture') + self.logger: logging.Logger = get_logger('screencapture') self.period = period self.timeout = timeout self.stop_signal = threading.Event() self.lock = threading.Lock() - self.last_poll = 0 - self.daemon = True - self.exc = None - self.output_path = None + self.last_poll: float = 0 + self.daemon: bool = True + self.exc: Optional[Exception] = None + self.output_path: Optional[str] = None - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: self.output_path = output_path - def run(self): + def run(self) -> None: + """ + Continuously capture screenshots at the specified interval until stopped. + """ self.logger.debug('Starting screen capture polling') try: if self.output_path is None: @@ -52,13 +65,16 @@ def run(self): break with self.lock: current_time = time.time() - if (current_time - self.last_poll) >= self.period: + if (current_time - self.last_poll) >= cast(float, self.period): self.poll() time.sleep(0.5) except Exception: # pylint: disable=W0703 self.exc = WorkerThreadError(self.name, sys.exc_info()) - def stop(self): + def stop(self) -> None: + """ + Signal the thread to stop and wait for it to exit, up to :attr:`timeout`. + """ self.logger.debug('Stopping screen capture polling') self.stop_signal.set() self.join(self.timeout) @@ -67,34 +83,46 @@ def stop(self): if self.exc: raise self.exc # pylint: disable=E0702 - def poll(self): + def poll(self) -> None: self.last_poll = time.time() - self.target.capture_screen(os.path.join(self.output_path, "screencap_{ts}.png")) + self.target.capture_screen(os.path.join(self.output_path or '', "screencap_{ts}.png")) class ScreenCaptureCollector(CollectorBase): + """ + A collector that periodically captures screenshots from a target device. 
- def __init__(self, target, period=None): + :param target: The devlib Target. + :param period: Interval in seconds between captures. + """ + def __init__(self, target: 'Target', period: Optional[float] = None): super(ScreenCaptureCollector, self).__init__(target) - self._collecting = False - self.output_path = None + self._collecting: bool = False + self.output_path: Optional[str] = None self.period = period self.target = target - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: self.output_path = output_path - def reset(self): + def reset(self) -> None: self._poller = ScreenCapturePoller(self.target, self.period) - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Return a :class:`CollectorOutput` referencing the directory of captured screenshots. + + :return: A collector output referencing the screenshot directory. + """ if self.output_path is None: raise RuntimeError("No data collected.") return CollectorOutput([CollectorOutputEntry(self.output_path, 'directory')]) - def start(self): + def start(self) -> None: """ - Start collecting the screenshots + Start the screen capture poller thread. + + :raises RuntimeError: If :attr:`output_path` is unset. """ if self.output_path is None: raise RuntimeError("Output path was not set.") @@ -102,9 +130,9 @@ def start(self): self._poller.start() self._collecting = True - def stop(self): + def stop(self) -> None: """ - Stop collecting the screenshots + Stop the screen capture poller thread. 
""" if not self._collecting: raise RuntimeError('Screen capture collector is not running, nothing to stop') diff --git a/devlib/collector/serial_trace.py b/devlib/collector/serial_trace.py index 7df9ab3ff..0d0e278af 100644 --- a/devlib/collector/serial_trace.py +++ b/devlib/collector/serial_trace.py @@ -1,4 +1,4 @@ -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,27 +18,41 @@ from devlib.collector import (CollectorBase, CollectorOutput, CollectorOutputEntry) from devlib.utils.serial_port import get_connection +from typing import TextIO, cast, TYPE_CHECKING, Optional +from pexpect import fdpexpect +from serial import Serial +from io import BufferedWriter +if TYPE_CHECKING: + from devlib.target import Target class SerialTraceCollector(CollectorBase): - + """ + A collector that reads serial output and saves it to a file. + + :param target: The devlib Target. + :param serial_port: The serial port to open. + :param baudrate: The baud rate (bits per second). + :param timeout: A timeout for serial reads, in seconds. 
+ """ @property - def collecting(self): + def collecting(self) -> bool: return self._collecting - def __init__(self, target, serial_port, baudrate, timeout=20): + def __init__(self, target: 'Target', serial_port: int, + baudrate: int, timeout: int = 20): super(SerialTraceCollector, self).__init__(target) self.serial_port = serial_port self.baudrate = baudrate self.timeout = timeout - self.output_path = None + self.output_path: Optional[str] = None - self._serial_target = None - self._conn = None - self._outfile_fh = None - self._collecting = False + self._serial_target: Optional[fdpexpect.fdspawn] = None + self._conn: Optional[Serial] = None + self._outfile_fh: Optional[BufferedWriter] = None + self._collecting: bool = False - def reset(self): + def reset(self) -> None: if self._collecting: raise RuntimeError("reset was called whilst collecting") @@ -46,24 +60,34 @@ def reset(self): self._outfile_fh.close() self._outfile_fh = None - def start(self): + def start(self) -> None: + """ + Open the serial connection and write all data to :attr:`output_path`. + + :raises RuntimeError: If already collecting or :attr:`output_path` is unset. + """ if self._collecting: raise RuntimeError("start was called whilst collecting") if self.output_path is None: raise RuntimeError("Output path was not set.") self._outfile_fh = open(self.output_path, 'wb') - start_marker = "-------- Starting serial logging --------\n" + start_marker: str = "-------- Starting serial logging --------\n" self._outfile_fh.write(start_marker.encode('utf-8')) self._serial_target, self._conn = get_connection(port=self.serial_port, baudrate=self.baudrate, timeout=self.timeout, - logfile=self._outfile_fh, - init_dtr=0) + logfile=cast(TextIO, self._outfile_fh), + init_dtr=False) self._collecting = True - def stop(self): + def stop(self) -> None: + """ + Close the serial connection and finalize the log file. + + :raises RuntimeError: If not currently collecting. 
+ """ if not self._collecting: raise RuntimeError("stop was called whilst not collecting") @@ -71,25 +95,33 @@ def stop(self): # do something so that it interacts with the serial device, # and hence updates the logfile. try: - self._serial_target.expect(".", timeout=1) + if self._serial_target: + self._serial_target.expect(".", timeout=1) except TIMEOUT: pass - - self._serial_target.close() + if self._serial_target: + self._serial_target.close() del self._conn - stop_marker = "-------- Stopping serial logging --------\n" - self._outfile_fh.write(stop_marker.encode('utf-8')) - self._outfile_fh.flush() - self._outfile_fh.close() - self._outfile_fh = None + stop_marker: str = "-------- Stopping serial logging --------\n" + if self._outfile_fh: + self._outfile_fh.write(stop_marker.encode('utf-8')) + self._outfile_fh.flush() + self._outfile_fh.close() + self._outfile_fh = None self._collecting = False - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: self.output_path = output_path - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Return a :class:`CollectorOutput` referencing the saved serial log file. + + :raises RuntimeError: If :attr:`output_path` is unset. + :return: A collector output referencing the serial log file. + """ if self._collecting: raise RuntimeError("get_data was called whilst collecting") if self.output_path is None: diff --git a/devlib/collector/systrace.py b/devlib/collector/systrace.py index 4e29cf11a..3b7374e32 100644 --- a/devlib/collector/systrace.py +++ b/devlib/collector/systrace.py @@ -1,4 +1,4 @@ -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -21,9 +21,12 @@ from devlib.exception import TargetStableError, HostError import devlib.utils.android from devlib.utils.misc import memoized +from typing import TYPE_CHECKING, List, Optional, Union, TextIO +from subprocess import Popen +if TYPE_CHECKING: + from devlib.target import AndroidTarget - -DEFAULT_CATEGORIES = [ +DEFAULT_CATEGORIES: List[str] = [ 'gfx', 'view', 'sched', @@ -31,6 +34,7 @@ 'idle' ] + class SystraceCollector(CollectorBase): """ A trace collector based on Systrace @@ -38,50 +42,48 @@ class SystraceCollector(CollectorBase): For more details, see https://developer.android.com/studio/command-line/systrace :param target: Devlib target - :type target: AndroidTarget :param outdir: Working directory to use on the host - :type outdir: str :param categories: Systrace categories to trace. See `available_categories` - :type categories: list(str) :param buffer_size: Buffer size in kb - :type buffer_size: int :param strict: Raise an exception if any of the requested categories are not available - :type strict: bool """ @property @memoized - def available_categories(self): - lines = subprocess.check_output( - [self.systrace_binary, '-l'], universal_newlines=True + def available_categories(self) -> List[str]: + """ + list of available categories + """ + lines: List[str] = subprocess.check_output( + [self.systrace_binary or '', '-l'], universal_newlines=True ).splitlines() return [line.split()[0] for line in lines if line] - def __init__(self, target, - categories=None, - buffer_size=None, - strict=False): + def __init__(self, target: 'AndroidTarget', + categories: Optional[str] = None, + buffer_size: Optional[int] = None, + strict: bool = False): super(SystraceCollector, self).__init__(target) - self.categories = categories or DEFAULT_CATEGORIES + self.categories: Union[str, List[str]] = categories or DEFAULT_CATEGORIES self.buffer_size = buffer_size - self.output_path = None + self.output_path: Optional[str] = None - self._systrace_process = None - 
self._outfile_fh = None + self._systrace_process: Optional[Popen] = None + self._outfile_fh: Optional[TextIO] = None # Try to find a systrace binary - self.systrace_binary = None + self.systrace_binary: Optional[str] = None - platform_tools = devlib.utils.android.platform_tools - systrace_binary_path = os.path.join(platform_tools, 'systrace', 'systrace.py') + platform_tools: str = devlib.utils.android.platform_tools # type: ignore + systrace_binary_path: str = os.path.join(platform_tools, 'systrace', 'systrace.py') if not os.path.isfile(systrace_binary_path): raise HostError('Could not find any systrace binary under {}'.format(platform_tools)) @@ -90,7 +92,7 @@ def __init__(self, target, # Filter the requested categories for category in self.categories: if category not in self.available_categories: - message = 'Category [{}] not available for tracing'.format(category) + message: str = 'Category [{}] not available for tracing'.format(category) if strict: raise TargetStableError(message) self.logger.warning(message) @@ -102,11 +104,14 @@ def __init__(self, target, def __del__(self): self.reset() - def _build_cmd(self): - self._outfile_fh = open(self.output_path, 'w') + def _build_cmd(self) -> None: + """ + build command + """ + self._outfile_fh = open(self.output_path or '', 'w') # pylint: disable=attribute-defined-outside-init - self.systrace_cmd = 'python2 -u {} -o {} -e {}'.format( + self.systrace_cmd: str = 'python2 -u {} -o {} -e {}'.format( self.systrace_binary, self._outfile_fh.name, self.target.adb_name @@ -117,11 +122,14 @@ def _build_cmd(self): self.systrace_cmd += ' {}'.format(' '.join(self.categories)) - def reset(self): + def reset(self) -> None: if self._systrace_process: self.stop() - def start(self): + def start(self) -> None: + """ + Start systrace, typically running a systrace command in the background. 
+ """ if self._systrace_process: raise RuntimeError("Tracing is already underway, call stop() first") if self.output_path is None: @@ -138,9 +146,13 @@ def start(self): shell=True, universal_newlines=True ) - self._systrace_process.stdout.read(1) + if self._systrace_process.stdout: + self._systrace_process.stdout.read(1) - def stop(self): + def stop(self) -> None: + """ + Stop systrace and finalize the trace file. + """ if not self._systrace_process: raise RuntimeError("No tracing to stop, call start() first") @@ -152,10 +164,16 @@ def stop(self): self._outfile_fh.close() self._outfile_fh = None - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: self.output_path = output_path - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Pull the trace HTML (or raw data) from the target and return a + :class:`CollectorOutput`. + + :return: A collector output referencing the systrace file. + """ if self._systrace_process: raise RuntimeError("Tracing is underway, call stop() first") if self.output_path is None: diff --git a/devlib/connection.py b/devlib/connection.py index 99055a3c1..352631e9b 100644 --- a/devlib/connection.py +++ b/devlib/connection.py @@ -1,4 +1,4 @@ -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -22,17 +22,66 @@ import subprocess import threading import time -import logging import select import fcntl -from devlib.utils.misc import InitCheckpoint, memoized +from devlib.utils.misc import InitCheckpoint, memoized, get_logger +from devlib.utils.annotation_helpers import SubprocessCommand +from typing import (Optional, TYPE_CHECKING, Set, + Tuple, IO, Dict, List, Union, + Callable, cast) +from collections.abc import Generator +from typing_extensions import Protocol, Literal + +if TYPE_CHECKING: + from signal import Signals + from subprocess import Popen + from threading import Lock, Thread, Event + from logging import Logger + from paramiko.channel import Channel + from paramiko.sftp_client import SFTPClient + from scp import SCPClient + + +class HasInitialized(Protocol): + """ + Protocol indicating that the object includes an ``initialized`` property + and a ``close()`` method. Used to ensure safe clean-up in destructors. + + :ivar initialized: ``True`` if the object finished initializing successfully, + otherwise ``False`` if initialization failed or is incomplete. + :vartype initialized: bool + """ + initialized: bool + + # other functions referred by the object with the initialized property + def close(self) -> None: + """ + Close method expected on objects that provide ``initialized``. + """ + ... + _KILL_TIMEOUT = 3 -def _popen_communicate(bg, popen, input, timeout): +def _popen_communicate(bg: 'BackgroundCommand', popen: 'Popen', input: bytes, + timeout: Optional[int]) -> Tuple[Optional[bytes], Optional[bytes]]: + """ + Wrapper around ``popen.communicate(...)`` to handle timeouts and + cancellation of a background command. + + :param bg: The associated :class:`BackgroundCommand` object that may be canceled. + :param popen: The :class:`subprocess.Popen` instance to communicate with. + :param input: Bytes to send to stdin. + :param timeout: The timeout in seconds or None for no timeout. 
+ :return: A tuple (stdout, stderr) if the command completes successfully. + :raises subprocess.TimeoutExpired: If the command doesn't complete in time. + :raises subprocess.CalledProcessError: If the command exits with a non-zero return code. + """ try: + stdout: Optional[bytes] + stderr: Optional[bytes] stdout, stderr = popen.communicate(input=input, timeout=timeout) except subprocess.TimeoutExpired: bg.cancel() @@ -53,19 +102,39 @@ def _popen_communicate(bg, popen, input, timeout): class ConnectionBase(InitCheckpoint): """ Base class for all connections. + A :class:`Connection` abstracts an actual physical connection to a device. The + first connection is created when :func:`Target.connect` method is called. If a + :class:`~devlib.target.Target` is used in a multi-threaded environment, it will + maintain a connection for each thread in which it is invoked. This allows + the same target object to be used in parallel in multiple threads. + + :class:`Connection` s will be automatically created and managed by + :class:`~devlib.target.Target` s, so there is usually no reason to create one + manually. Instead, configuration for a :class:`Connection` is passed as + `connection_settings` parameter when creating a + :class:`~devlib.target.Target`. The connection to be used target is also + specified on instantiation by `conn_cls` parameter, though all concrete + :class:`~devlib.target.Target` implementations will set an appropriate + default, so there is typically no need to specify this explicitly. + + :param poll_transfers: If True, manage file transfers by polling for progress. + :param start_transfer_poll_delay: Delay in seconds before first checking a + file transfer's progress. + :param total_transfer_timeout: Cancel transfers if they exceed this many seconds. + :param transfer_poll_period: Interval (seconds) between transfer progress checks. 
""" def __init__( self, - poll_transfers=False, - start_transfer_poll_delay=30, - total_transfer_timeout=3600, - transfer_poll_period=30, + poll_transfers: bool = False, + start_transfer_poll_delay: int = 30, + total_transfer_timeout: int = 3600, + transfer_poll_period: int = 30, ): - self._current_bg_cmds = set() - self._closed = False - self._close_lock = threading.Lock() - self.busybox = None - self.logger = logging.getLogger('Connection') + self._current_bg_cmds: Set['BackgroundCommand'] = set() + self._closed: bool = False + self._close_lock: Lock = threading.Lock() + self.busybox: Optional[str] = None + self.logger: Logger = get_logger('Connection') self.transfer_manager = TransferManager( self, @@ -74,14 +143,17 @@ def __init__( transfer_poll_period=transfer_poll_period, ) if poll_transfers else NoopTransferManager() - - def cancel_running_command(self): - bg_cmds = set(self._current_bg_cmds) + def cancel_running_command(self) -> Optional[bool]: + """ + Cancel all active background commands tracked by this connection. + """ + bg_cmds: Set['BackgroundCommand'] = set(self._current_bg_cmds) for bg_cmd in bg_cmds: bg_cmd.cancel() + return None @abstractmethod - def _close(self): + def _close(self) -> None: """ Close the connection. @@ -90,11 +162,14 @@ def _close(self): be called from multiple threads at once. """ - def close(self): - - def finish_bg(): - bg_cmds = set(self._current_bg_cmds) - n = len(bg_cmds) + def close(self) -> None: + """ + Cancel any ongoing commands and finalize the connection. Safe to call multiple times, + does nothing after the first invocation. 
+ """ + def finish_bg() -> None: + bg_cmds: Set['BackgroundCommand'] = set(self._current_bg_cmds) + n: int = len(bg_cmds) if n: self.logger.debug(f'Canceling {n} background commands before closing connection') for bg_cmd in bg_cmds: @@ -111,13 +186,35 @@ def finish_bg(): # Ideally, that should not be relied upon but that will improve the chances # of the connection being properly cleaned up when it's not in use anymore. - def __del__(self): + def __del__(self: HasInitialized): + """ + Destructor ensuring the connection is closed if not already. Only runs + if object initialization succeeded (initialized=True). + """ # Since __del__ will be called if an exception is raised in __init__ # (e.g. we cannot connect), we only run close() when we are sure # __init__ has completed successfully. if self.initialized: self.close() + @abstractmethod + def execute(self, command: 'SubprocessCommand', timeout: Optional[int] = None, + check_exit_code: bool = True, as_root: Optional[bool] = False, + strip_colors: bool = True, will_succeed: bool = False) -> str: + """ + Execute a shell command and return the combined stdout/stderr. + + :param command: Command string or SubprocessCommand detailing the command to run. + :param timeout: Timeout in seconds (None for no limit). + :param check_exit_code: If True, raise an error if exit code is non-zero. + :param as_root: If True, attempt to run with elevated privileges. + :param strip_colors: Remove ANSI color codes from output if True. + :param will_succeed: If True, interpret a failing command as a transient environment error. + :returns: The command's combined stdout and stderr. + :raises DevlibTransientError: If the command fails and is considered transient (will_succeed=True). + :raises DevlibStableError: If the command fails in a stable way (exit code != 0, or other error). 
+ """ + class BackgroundCommand(ABC): """ @@ -126,9 +223,12 @@ class BackgroundCommand(ABC): Instances of this class can be used as context managers, with the same semantic as :class:`subprocess.Popen`. + + :param conn: The connection that owns this background command. """ - def __init__(self, conn, data_dir, cmd, as_root): + def __init__(self, conn: 'ConnectionBase', data_dir: str, + cmd: 'SubprocessCommand', as_root: Optional[bool]): self.conn = conn self._data_dir = data_dir self.as_root = as_root @@ -149,7 +249,8 @@ def __init__(self, conn, data_dir, cmd, as_root): conn._current_bg_cmds.add(self) @classmethod - def from_factory(cls, conn, cmd, as_root, make_init_kwargs): + def from_factory(cls, conn: 'ConnectionBase', cmd: 'SubprocessCommand', as_root: Optional[bool], + make_init_kwargs): cmd, data_dir = cls._with_data_dir(conn, cmd) return cls( conn=conn, @@ -159,7 +260,10 @@ def from_factory(cls, conn, cmd, as_root, make_init_kwargs): **make_init_kwargs(cmd), ) - def _deregister(self): + def _deregister(self) -> None: + """ + deregister the background command + """ try: self.conn._current_bg_cmds.remove(self) except KeyError: @@ -171,14 +275,14 @@ def _pid_file(self): @property @memoized - def _targeted_pid(self): + def _targeted_pid(self) -> int: """ PID of the process pointed at by ``devlib-signal-target`` command. 
""" - path = quote(self._pid_file) - busybox = quote(self.conn.busybox) + path: str = quote(self._pid_file) + busybox: str = quote(self.conn.busybox or '') - def execute(cmd): + def execute(cmd: 'SubprocessCommand'): return self.conn.execute(cmd, as_root=self.as_root) while self.poll() is None: @@ -193,35 +297,34 @@ def execute(cmd): # We got a partial write in the PID file continue - raise ValueError(f'The background commmand did not use devlib-signal-target wrapper to designate which command should be the target of signals') + raise ValueError('The background commmand did not use devlib-signal-target wrapper to designate which command should be the target of signals') @classmethod - def _with_data_dir(cls, conn, cmd): - busybox = quote(conn.busybox) + def _with_data_dir(cls, conn: 'ConnectionBase', cmd: 'SubprocessCommand'): + busybox = quote(conn.busybox or '') data_dir = conn.execute(f'{busybox} mktemp -d').strip() - cmd = f'_DEVLIB_BG_CMD_DATA_DIR={data_dir} exec {busybox} sh -c {quote(cmd)}' + cmd = f'_DEVLIB_BG_CMD_DATA_DIR={data_dir} exec {busybox} sh -c {quote(cast(str, cmd))}' return cmd, data_dir - def _cleanup_data_dir(self): + def _cleanup_data_dir(self) -> None: path = quote(self._data_dir) - busybox = quote(self.conn.busybox) + busybox = quote(self.conn.busybox or '') cmd = f'{busybox} rm -r {path} || true' self.conn.execute(cmd, as_root=self.as_root) - def send_signal(self, sig): + def send_signal(self, sig: 'Signals') -> None: """ Send a POSIX signal to the background command's process group ID (PGID). :param signal: Signal to send. 
- :type signal: signal.Signals """ - def execute(cmd): + def execute(cmd: 'SubprocessCommand'): return self.conn.execute(cmd, as_root=self.as_root) - def send(sig): - busybox = quote(self.conn.busybox) + def send(sig: 'Signals') -> None: + busybox: str = quote(self.conn.busybox or '') # If the command has already completed, we don't want to send a # signal to another process that might have gotten that PID in the # meantime. @@ -235,7 +338,7 @@ def send(sig): # Other signals require cooperation from the shell command # so that it points to a specific process using # devlib-signal-target - pid = self._targeted_pid + pid: int = self._targeted_pid execute(f'{busybox} kill -{sig.value} {pid}') try: return send(sig) @@ -243,16 +346,18 @@ def send(sig): # Deregister if the command has finished self.poll() - def kill(self): + def kill(self) -> None: """ Send SIGKILL to the background command. """ self.send_signal(signal.SIGKILL) - def cancel(self, kill_timeout=_KILL_TIMEOUT): + def cancel(self, kill_timeout: int = _KILL_TIMEOUT) -> None: """ Try to gracefully terminate the process by sending ``SIGTERM``, then waiting for ``kill_timeout`` to send ``SIGKILL``. + + :param kill_timeout: Seconds to wait between SIGTERM and SIGKILL. """ try: if self.poll() is None: @@ -261,30 +366,40 @@ def cancel(self, kill_timeout=_KILL_TIMEOUT): self._deregister() @abstractmethod - def _cancel(self, kill_timeout): + def _cancel(self, kill_timeout: int) -> None: """ - Method to override in subclasses to implement :meth:`cancel`. + Subclass-specific logic for :meth:`cancel`. Usually sends SIGTERM, waits, + then sends SIGKILL if needed. """ pass @abstractmethod - def _wait(self): + def _wait(self) -> int: + """ + Wait for the command to complete. Return its exit code. + """ pass - def wait(self): + def wait(self) -> int: """ - Block until the background command completes, and return its exit code. + Block until the command completes, returning the exit code. 
+ + :returns: The exit code of the command. """ try: return self._wait() finally: self._deregister() - def communicate(self, input=b'', timeout=None): + def communicate(self, input: bytes = b'', timeout: Optional[int] = None) -> Tuple[Optional[bytes], Optional[bytes]]: """ - Block until the background command completes while reading stdout and stderr. - Return ``tuple(stdout, stderr)``. If the return code is non-zero, - raises a :exc:`subprocess.CalledProcessError` exception. + Write to stdin and read all data from stdout/stderr until the command exits. + + :param input: Bytes to send to stdin. + :param timeout: Max time to wait for the command to exit, or None if indefinite. + :returns: A tuple of (stdout, stderr) if the command exits cleanly. + :raises subprocess.TimeoutExpired: If the process runs past the timeout. + :raises subprocess.CalledProcessError: If the process exits with a non-zero code. """ try: return self._communicate(input=input, timeout=timeout) @@ -292,16 +407,25 @@ def communicate(self, input=b'', timeout=None): self.close() @abstractmethod - def _communicate(self, input, timeout): + def _communicate(self, input: bytes, timeout: Optional[int]) -> Tuple[Optional[bytes], Optional[bytes]]: + """ + Method to override in subclasses to implement :meth:`communicate`. + """ pass @abstractmethod - def _poll(self): + def _poll(self) -> Optional[int]: + """ + Method to override in subclasses to implement :meth:`poll`. + """ pass - def poll(self): + def poll(self) -> Optional[int]: """ - Return exit code if the command has exited, None otherwise. + Return the exit code if the command has finished, otherwise None. + Deregisters if the command is done. + + :returns: Exit code or None if ongoing. """ retcode = self._poll() if retcode is not None: @@ -310,28 +434,28 @@ def poll(self): @property @abstractmethod - def stdin(self): + def stdin(self) -> Optional[IO]: """ - File-like object connected to the background's command stdin. 
+ A file-like object representing this command's standard input. May be None if unsupported. """ @property @abstractmethod - def stdout(self): + def stdout(self) -> Optional[IO]: """ - File-like object connected to the background's command stdout. + A file-like object representing this command's standard output. May be None. """ @property @abstractmethod - def stderr(self): + def stderr(self) -> Optional[IO]: """ - File-like object connected to the background's command stderr. + A file-like object representing this command's standard error. May be None. """ @property @abstractmethod - def pid(self): + def pid(self) -> int: """ Process Group ID (PGID) of the background command. @@ -343,14 +467,17 @@ def pid(self): """ @abstractmethod - def _close(self): + def _close(self) -> int: + """ + Subclass hook for final cleanup: close streams, wait for exit, return exit code. + """ pass - def close(self): + def close(self) -> int: """ - Close all opened streams and then wait for command completion. + Close any open streams and finalize the command. Return exit code. - :returns: Exit code of the command. + :returns: The command's final exit code. .. note:: If the command is writing to its stdout/stderr, it might be blocked on that and die when the streams are closed. @@ -370,10 +497,15 @@ def __exit__(self, *args, **kwargs): class PopenBackgroundCommand(BackgroundCommand): """ - :class:`subprocess.Popen`-based background command. + Runs a command via ``subprocess.Popen`` in the background. Signals are sent + to the process group. Streams are accessible via ``stdin``, ``stdout``, and ``stderr``. + + :param conn: The parent connection. + :param popen: The Popen object controlling the shell command. 
""" - def __init__(self, conn, data_dir, cmd, as_root, popen): + def __init__(self, conn: 'ConnectionBase', data_dir: str, cmd, + as_root: Optional[bool], popen: 'Popen'): super().__init__( conn=conn, data_dir=data_dir, @@ -383,31 +515,31 @@ def __init__(self, conn, data_dir, cmd, as_root, popen): self.popen = popen @property - def stdin(self): + def stdin(self) -> Optional[IO]: return self.popen.stdin @property - def stdout(self): + def stdout(self) -> Optional[IO]: return self.popen.stdout @property - def stderr(self): + def stderr(self) -> Optional[IO]: return self.popen.stderr @property - def pid(self): + def pid(self) -> int: return self.popen.pid - def _wait(self): + def _wait(self) -> int: return self.popen.wait() - def _communicate(self, input, timeout): + def _communicate(self, input: bytes, timeout: Optional[int]) -> Tuple[Optional[bytes], Optional[bytes]]: return _popen_communicate(self, self.popen, input, timeout) - def _poll(self): + def _poll(self) -> Optional[int]: return self.popen.poll() - def _cancel(self, kill_timeout): + def _cancel(self, kill_timeout: int) -> None: popen = self.popen os.killpg(os.getpgid(popen.pid), signal.SIGTERM) try: @@ -415,7 +547,7 @@ def _cancel(self, kill_timeout): except subprocess.TimeoutExpired: os.killpg(os.getpgid(popen.pid), signal.SIGKILL) - def _close(self): + def _close(self) -> int: self.popen.__exit__(None, None, None) return self.popen.returncode @@ -427,9 +559,22 @@ def __enter__(self): class ParamikoBackgroundCommand(BackgroundCommand): """ - :mod:`paramiko`-based background command. + Background command using a Paramiko :class:`Channel` for remote SSH-based execution. + Handles signals by running kill commands on the remote, using the PGID. + + :param conn: The SSH-based connection. + :param chan: The Paramiko channel running the remote command. + :param pid: Remote process group ID for signaling. + :param as_root: True if run with elevated privileges. 
+ :param cmd: The shell command executed (for reference). + :param stdin: A file-like object to write into the remote stdin. + :param stdout: A file-like object for reading from the remote stdout. + :param stderr: A file-like object for reading from the remote stderr. + :param redirect_thread: A thread that captures data from the channel and writes to + stdout/stderr pipes. """ - def __init__(self, conn, data_dir, cmd, as_root, chan, pid, stdin, stdout, stderr, redirect_thread): + def __init__(self, conn: 'ConnectionBase', data_dir: str, cmd, as_root: Optional[bool], chan: 'Channel', + pid: int, stdin: IO, stdout: IO, stderr: IO, redirect_thread: 'Thread'): super().__init__( conn=conn, data_dir=data_dir, @@ -445,20 +590,23 @@ def __init__(self, conn, data_dir, cmd, as_root, chan, pid, stdin, stdout, stder self.redirect_thread = redirect_thread @property - def pid(self): + def pid(self) -> int: return self._pid - def _wait(self): + def _wait(self) -> int: status = self.chan.recv_exit_status() # Ensure that the redirection thread is finished copying the content # from paramiko to the pipe. self.redirect_thread.join() return status - def _communicate(self, input, timeout): + def _communicate(self, input: bytes, timeout: Optional[int]) -> Tuple[Optional[bytes], Optional[bytes]]: + """ + Implementation for reading from stdout/stderr, writing to stdin, + handling timeouts, etc. Raise an error if non-zero exit or timeout. 
+ """ stdout = self._stdout stderr = self._stderr - stdin = self._stdin chan = self.chan # For some reason, file descriptors in the read-list of select() can @@ -469,21 +617,21 @@ def _communicate(self, input, timeout): for s in (stdout, stderr): fcntl.fcntl(s.fileno(), fcntl.F_SETFL, os.O_NONBLOCK) - out = {stdout: [], stderr: []} - ret = None - can_send = True + out: Dict[IO, List[bytes]] = {stdout: [], stderr: []} + ret: Optional[int] = None + can_send: bool = True - select_timeout = 1 + select_timeout: int = 1 if timeout is not None: select_timeout = min(select_timeout, 1) - def create_out(): + def create_out() -> Tuple[bytes, bytes]: return ( b''.join(out[stdout]), b''.join(out[stderr]) ) - start = time.monotonic() + start: float = time.monotonic() while ret is None: # Even if ret is not None anymore, we need to drain the streams @@ -495,11 +643,11 @@ def create_out(): raise subprocess.TimeoutExpired(self.cmd, timeout, _stdout, _stderr) can_send &= (not chan.closed) & bool(input) - wlist = [chan] if can_send else [] + wlist: List[Channel] = [chan] if can_send else [] if can_send and chan.send_ready(): try: - n = chan.send(input) + n: int = chan.send(input) # stdin might have been closed already except OSError: can_send = False @@ -509,7 +657,8 @@ def create_out(): if not input: # Send EOF on stdin chan.shutdown_write() - + rs: List[IO] + ws: List[IO] rs, ws, _ = select.select( [x for x in (stdout, stderr) if not x.closed], wlist, @@ -518,7 +667,7 @@ def create_out(): ) for r in rs: - chunk = r.read() + chunk: bytes = r.read() if chunk: out[r].append(chunk) @@ -534,7 +683,7 @@ def create_out(): else: return (_stdout, _stderr) - def _poll(self): + def _poll(self) -> Optional[int]: # Wait for the redirection thread to finish, otherwise we would # indicate the caller that the command is finished and that the streams # are safe to drain, but actually the redirection thread is not @@ -546,7 +695,7 @@ def _poll(self): else: return None - def _cancel(self, 
kill_timeout): + def _cancel(self, kill_timeout: int) -> None: self.send_signal(signal.SIGTERM) # Check if the command terminated quickly time.sleep(10e-3) @@ -557,24 +706,24 @@ def _cancel(self, kill_timeout): self.wait() @property - def stdin(self): + def stdin(self) -> Optional[IO]: return self._stdin @property - def stdout(self): + def stdout(self) -> Optional[IO]: return self._stdout @property - def stderr(self): + def stderr(self) -> Optional[IO]: return self._stderr - def _close(self): + def _close(self) -> int: for x in (self.stdin, self.stdout, self.stderr): if x is not None: x.close() - exit_code = self.wait() - thread = self.redirect_thread + exit_code: int = self.wait() + thread: Thread = self.redirect_thread if thread: thread.join() @@ -583,10 +732,17 @@ def _close(self): class AdbBackgroundCommand(BackgroundCommand): """ - ``adb``-based background command. + A background command launched through ADB. Manages signals by sending + kill commands on the remote Android device. + + :param conn: The ADB-based connection. + :param adb_popen: A subprocess.Popen object representing 'adb shell' or similar. + :param pid: Remote process group ID used for signals. + :param as_root: If True, signals are sent as root. 
""" - def __init__(self, conn, data_dir, cmd, as_root, adb_popen, pid): + def __init__(self, conn: 'ConnectionBase', data_dir: str, cmd: 'SubprocessCommand', + as_root: Optional[bool], adb_popen: 'Popen', pid: int): super().__init__( conn=conn, data_dir=data_dir, @@ -597,31 +753,32 @@ def __init__(self, conn, data_dir, cmd, as_root, adb_popen, pid): self._pid = pid @property - def stdin(self): + def stdin(self) -> Optional[IO]: return self.adb_popen.stdin @property - def stdout(self): + def stdout(self) -> Optional[IO]: return self.adb_popen.stdout @property - def stderr(self): + def stderr(self) -> Optional[IO]: return self.adb_popen.stderr @property - def pid(self): + def pid(self) -> int: return self._pid - def _wait(self): + def _wait(self) -> int: return self.adb_popen.wait() - def _communicate(self, input, timeout): + def _communicate(self, input: bytes, + timeout: Optional[int]) -> Tuple[Optional[bytes], Optional[bytes]]: return _popen_communicate(self, self.adb_popen, input, timeout) - def _poll(self): + def _poll(self) -> Optional[int]: return self.adb_popen.poll() - def _cancel(self, kill_timeout): + def _cancel(self, kill_timeout: int) -> None: self.send_signal(signal.SIGTERM) try: self.adb_popen.wait(timeout=kill_timeout) @@ -629,7 +786,7 @@ def _cancel(self, kill_timeout): self.send_signal(signal.SIGKILL) self.adb_popen.kill() - def _close(self): + def _close(self) -> int: self.adb_popen.__exit__(None, None, None) return self.adb_popen.returncode @@ -640,23 +797,51 @@ def __enter__(self): class TransferManager: - def __init__(self, conn, transfer_poll_period=30, start_transfer_poll_delay=30, total_transfer_timeout=3600): + """ + Monitors active file transfers (push or pull) in a background thread + and aborts them if they exceed a time limit or appear inactive. + + :param conn: The ConnectionBase owning this manager. + :param transfer_poll_period: Interval (seconds) between checks for activity. 
+ :param start_transfer_poll_delay: Delay (seconds) before starting to poll a new transfer. + :param total_transfer_timeout: Cancel the transfer if it exceeds this duration. + """ + def __init__(self, conn: 'ConnectionBase', transfer_poll_period: int = 30, + start_transfer_poll_delay: int = 30, total_transfer_timeout: int = 3600): self.conn = conn self.transfer_poll_period = transfer_poll_period self.total_transfer_timeout = total_transfer_timeout self.start_transfer_poll_delay = start_transfer_poll_delay - self.logger = logging.getLogger('FileTransfer') + self.logger = get_logger('FileTransfer') @contextmanager - def manage(self, sources, dest, direction, handle): - excep = None - stop_thread = threading.Event() + def manage(self, sources: List[str], dest: str, + direction: Union[Literal['push'], Literal['pull']], + handle: 'TransferHandleBase') -> Generator: + """ + A context manager that spawns a thread to monitor file transfer progress. + If the transfer stalls or times out, it cancels the operation. + + :param sources: Paths being transferred. + :param dest: Destination path. + :param direction: 'push' or 'pull' for transfer direction. + :param handle: A TransferHandleBase for polling/canceling. + :raises TimeoutError: If the transfer times out. 
+ """ + excep: Optional[TimeoutError] = None + stop_thread: Event = threading.Event() - def monitor(): + def monitor() -> None: + """ + thread to monitor the file transfer + """ nonlocal excep - def cancel(reason): + def cancel(reason: str) -> None: + """ + cancel the file transfer + """ self.logger.warning( f'Cancelling file transfer {sources} -> {dest} due to: {reason}' ) @@ -671,7 +856,7 @@ def cancel(reason): cancel(reason='transfer timed out') excep = TimeoutError(f'{direction}: {sources} -> {dest}') - m_thread = threading.Thread(target=monitor, daemon=True) + m_thread: Thread = threading.Thread(target=monitor, daemon=True) try: m_thread.start() yield self @@ -683,33 +868,51 @@ def cancel(reason): class NoopTransferManager: - def manage(self, *args, **kwargs): + """ + A manager that does nothing for transfers. Used if polling is disabled. + """ + def manage(self, *args, **kwargs) -> nullcontext: return nullcontext(self) class TransferHandleBase(ABC): - def __init__(self, manager): + """ + Abstract base for objects tracking a file transfer's progress and allowing cancellations. + + :param manager: The TransferManager that created this handle. + """ + def __init__(self, manager: 'TransferManager'): self.manager = manager @property def logger(self): + """ + get the logger for transfer manager + """ return self.manager.logger @abstractmethod - def isactive(self): + def isactive(self) -> bool: + """ + Check if the transfer still appears to be making progress (return True) + or if it is idle/complete (return False). 
+ """ pass @abstractmethod - def cancel(self): + def cancel(self) -> None: + """ + cancel ongoing file transfer + """ pass class PopenTransferHandle(TransferHandleBase): - def __init__(self, popen, dest, direction, *args, **kwargs): + def __init__(self, popen, dest: str, direction: Union[Literal['push'], Literal['pull']], *args, **kwargs): super().__init__(*args, **kwargs) if direction == 'push': - sample_size = self._push_dest_size + sample_size: Callable[[str], Optional[int]] = self._push_dest_size elif direction == 'pull': sample_size = self._pull_dest_size else: @@ -721,35 +924,52 @@ def __init__(self, popen, dest, direction, *args, **kwargs): self.last_sample = 0 @staticmethod - def _pull_dest_size(dest): + def _pull_dest_size(dest: str) -> Optional[int]: + """ + Compute total size of a directory or file at the local ``dest`` path. + Returns None if it does not exist. + """ if os.path.isdir(dest): return sum( os.stat(os.path.join(dirpath, f)).st_size - for dirpath, _, fnames in os.walk(dest) - for f in fnames + for dirpath, _, fnames in os.walk(dest) + for f in fnames ) else: return os.stat(dest).st_size - def _push_dest_size(self, dest): - conn = self.manager.conn - cmd = '{} du -s -- {}'.format(quote(conn.busybox), quote(dest)) - out = conn.execute(cmd) - return int(out.split()[0]) - - def cancel(self): + def _push_dest_size(self, dest: str) -> Optional[int]: + """ + Compute total size of a directory or file on the remote device, + using busybox du if available. + """ + conn: 'ConnectionBase' = self.manager.conn + if conn.busybox: + cmd: str = '{} du -s -- {}'.format(quote(conn.busybox), quote(dest)) + out: str = conn.execute(cmd) + return int(out.split()[0]) + return None + + def cancel(self) -> None: + """ + Cancel the underlying background command, aborting the file transfer. + """ self.popen.terminate() - def isactive(self): + def isactive(self) -> bool: + """ + Check if the file size at the destination has grown since the last poll. 
+ Returns True if so, otherwise might still be True if we can't read size. + """ try: - curr_size = self.sample_size() + curr_size: Optional[int] = self.sample_size() except Exception as e: self.logger.debug(f'File size polling failed: {e}') return True else: self.logger.debug(f'Polled file transfer, destination size: {curr_size}') if curr_size: - active = curr_size > self.last_sample + active: bool = curr_size > self.last_sample self.last_sample = curr_size return active # If the file is empty it will never grow in size, so we assume @@ -759,21 +979,32 @@ def isactive(self): class SSHTransferHandle(TransferHandleBase): + """ + SCP or SFTP-based file transfer handle that uses a callback to track progress. - def __init__(self, handle, *args, **kwargs): + :param handle: The SCPClient or SFTPClient controlling the file transfer. + """ + + def __init__(self, handle: Union['SCPClient', 'SFTPClient'], *args, **kwargs): super().__init__(*args, **kwargs) # SFTPClient or SSHClient self.handle = handle - self.progressed = False - self.transferred = 0 - self.to_transfer = 0 + self.progressed: bool = False + self.transferred: int = 0 + self.to_transfer: int = 0 - def cancel(self): + def cancel(self) -> None: + """ + Close the underlying SCP or SFTP client, presumably aborting the transfer. + """ self.handle.close() def isactive(self): + """ + Return True if we've seen progress since last poll, otherwise False. + """ progressed = self.progressed if progressed: self.progressed = False @@ -783,7 +1014,13 @@ def isactive(self): ) return progressed - def progress_cb(self, transferred, to_transfer): + def progress_cb(self, transferred: int, to_transfer: int) -> None: + """ + Callback to be called by the SCP/SFTP library on each progress update. + + :param transferred: Bytes transferred so far. + :param to_transfer: Total bytes to transfer, or 0 if unknown. 
+ """ self.progressed = True self.transferred = transferred self.to_transfer = to_transfer diff --git a/devlib/exception.py b/devlib/exception.py index 33ef3c099..403fe2c19 100644 --- a/devlib/exception.py +++ b/devlib/exception.py @@ -1,4 +1,4 @@ -# Copyright 2013-2018 ARM Limited +# Copyright 2013-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,12 +14,16 @@ # import subprocess +from typing import cast, Optional, List + +from devlib.utils.annotation_helpers import SubprocessCommand + class DevlibError(Exception): """Base class for all Devlib exceptions.""" - def __init__(self, *args): - message = args[0] if args else None + def __init__(self, *args) -> None: + message: Optional[object] = args[0] if args else None self._message = message @property @@ -73,18 +77,20 @@ class TargetStableError(TargetError, DevlibStableError): class TargetCalledProcessError(subprocess.CalledProcessError, TargetError): """Exception raised when a command executed on the target fails.""" - def __str__(self): + + def __str__(self) -> str: msg = super().__str__() - def decode(s): + + def decode(s: bytes) -> str: try: - s = s.decode() + st = s.decode() except AttributeError: - s = str(s) + st = str(s) - return s.strip() + return st.strip() if self.stdout is not None and self.stderr is None: - out = ['OUTPUT: {}'.format(decode(self.output))] + out: List[str] = ['OUTPUT: {}'.format(decode(self.output))] else: out = [ 'STDOUT: {}'.format(decode(self.output)) if self.output is not None else '', @@ -124,13 +130,13 @@ class TimeoutError(DevlibTransientError): programming error (e.g. 
not setting long enough timers), it is often due to some failure in the environment, and there fore should be classed as a "user error".""" - def __init__(self, command, output): - super(TimeoutError, self).__init__('Timed out: {}'.format(command)) + def __init__(self, command: Optional[SubprocessCommand], output: Optional[str]): + super(TimeoutError, self).__init__('Timed out: {}'.format(cast(str, command))) self.command = command self.output = output def __str__(self): - return '\n'.join([self.message, 'OUTPUT:', self.output or '']) + return '\n'.join([cast(str, self.message), 'OUTPUT:', self.output or '']) class WorkerThreadError(DevlibError): diff --git a/devlib/host.py b/devlib/host.py index a20711cc4..8b3727b5c 100644 --- a/devlib/host.py +++ b/devlib/host.py @@ -1,4 +1,4 @@ -# Copyright 2015-2024 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ import signal import shutil import subprocess -import logging import sys from getpass import getpass from shlex import quote @@ -24,12 +23,28 @@ from devlib.exception import ( TargetStableError, TargetTransientCalledProcessError, TargetStableCalledProcessError ) -from devlib.utils.misc import check_output +from devlib.utils.misc import check_output, get_logger from devlib.connection import ConnectionBase, PopenBackgroundCommand - +from typing import Optional, TYPE_CHECKING, cast, Union, List +from typing_extensions import Literal +if TYPE_CHECKING: + from devlib.platform import Platform + from devlib.utils.annotation_helpers import SubprocessCommand + from signal import Signals + from logging import Logger if sys.version_info >= (3, 8): - def copy_tree(src, dst): + def copy_tree(src: str, dst: str) -> None: + """ + Recursively copy an entire directory tree from ``src`` to ``dst``, + preserving the directory structure but **not** file metadata + (modification times, 
modes, etc.). If ``dst`` already exists, this + overwrites matching files. + + :param src: The source directory path. + :param dst: The destination directory path. + :raises OSError: If any file or directory within ``src`` cannot be copied. + """ from shutil import copy, copytree copytree( src, @@ -42,17 +57,40 @@ def copy_tree(src, dst): ) else: def copy_tree(src, dst): + """ + Recursively copy an entire directory tree from ``src`` to ``dst``, + preserving the directory structure but **not** file metadata + (modification times, modes, etc.). If ``dst`` already exists, this + overwrites matching files. + + :param src: The source directory path. + :param dst: The destination directory path. + :raises OSError: If any file or directory within ``src`` cannot be copied. + + .. note:: + This uses :func:`distutils.dir_util.copy_tree` under Python < 3.8, which + does not support ``dirs_exist_ok=True``. The behavior is effectively the same + for overwriting existing paths. + """ from distutils.dir_util import copy_tree # Mirror the behavior of all other targets which only copy the # content without metadata copy_tree(src, dst, preserve_mode=False, preserve_times=False) -PACKAGE_BIN_DIRECTORY = os.path.join(os.path.dirname(__file__), 'bin') +PACKAGE_BIN_DIRECTORY: str = os.path.join(os.path.dirname(__file__), 'bin') # pylint: disable=redefined-outer-name -def kill_children(pid, signal=signal.SIGKILL): +def kill_children(pid: int, signal: 'Signals' = signal.SIGKILL) -> None: + """ + Recursively kill all child processes of the specified process ID, then kill + the process itself with the given signal. + + :param pid: The process ID whose children (and itself) will be killed. + :param signal_: The signal to send (defaults to SIGKILL). + :raises ProcessLookupError: If any child process does not exist (e.g., race conditions). 
+ """ with open('/proc/{0}/task/{0}/children'.format(pid), 'r') as fd: for cpid in map(int, fd.read().strip().split()): kill_children(cpid, signal) @@ -60,62 +98,147 @@ def kill_children(pid, signal=signal.SIGKILL): class LocalConnection(ConnectionBase): + """ + A connection to the local host, allowing the local system to be treated as a + devlib Target. Commands are run directly via :mod:`subprocess`, rather than + an SSH or ADB connection. + :param platform: A devlib Platform object for describing this local system + (e.g., CPU topology). If None, defaults may be used. + :param keep_password: If ``True``, cache the user’s sudo password in memory + after prompting. Defaults to True. + :param unrooted: If ``True``, assume the local system is non-root and do not + attempt root commands. This avoids prompting for a password. + :param password: Password for sudo. If provided, will not prompt the user. + :param timeout: A default timeout (in seconds) for connection-based operations. + """ name = 'local' host = 'localhost' + # pylint: disable=unused-argument + def __init__(self, platform: Optional['Platform'] = None, + keep_password: bool = True, unrooted: bool = False, + password: Optional[str] = None, timeout: Optional[int] = None): + """ + Initialize the LocalConnection instance. + """ + super().__init__() + self._connected_as_root: Optional[bool] = None + self.logger: Logger = get_logger('local_connection') + self.keep_password: bool = keep_password + self.unrooted: bool = unrooted + self.password: Optional[str] = password + @property - def connected_as_root(self): + def connected_as_root(self) -> Optional[bool]: + """ + Indicate whether the current user context is effectively 'root' (uid=0). 
+ + :return: + - True if root + - False if not root + - None if undetermined + """ if self._connected_as_root is None: - result = self.execute('id', as_root=False) + result: str = self.execute('id', as_root=False) self._connected_as_root = 'uid=0(' in result return self._connected_as_root @connected_as_root.setter - def connected_as_root(self, state): - self._connected_as_root = state + def connected_as_root(self, state: Optional[bool]) -> None: + """ + Override the known 'connected_as_root' state, if needed. - # pylint: disable=unused-argument - def __init__(self, platform=None, keep_password=True, unrooted=False, - password=None, timeout=None): - super().__init__() - self._connected_as_root = None - self.logger = logging.getLogger('local_connection') - self.keep_password = keep_password - self.unrooted = unrooted - self.password = password + :param state: True if effectively root, False if not, or None if unknown. + """ + self._connected_as_root = state + def _copy_path(self, source: str, dest: str) -> None: + """ + Copy a single file or directory from ``source`` to ``dest``. If ``source`` + is a directory, it is copied recursively. - def _copy_path(self, source, dest): + :param source: The path to the file or directory on the local system. + :param dest: Destination path. + :raises OSError: If any part of the copy operation fails. + """ self.logger.debug('copying {} to {}'.format(source, dest)) if os.path.isdir(source): copy_tree(source, dest) else: shutil.copy(source, dest) - def _copy_paths(self, sources, dest): + def _copy_paths(self, sources: List[str], dest: str) -> None: + """ + Copy multiple paths (files or directories) to the same destination. + + :param sources: A tuple of file or directory paths to copy. + :param dest: The destination path, which may be a directory. + :raises OSError: If any part of a copy operation fails. 
+ """ for source in sources: self._copy_path(source, dest) - def push(self, sources, dest, timeout=None, as_root=False): # pylint: disable=unused-argument + def push(self, sources: List[str], dest: str, timeout: Optional[int] = None, + as_root: bool = False) -> None: # pylint: disable=unused-argument + """ + Transfer a list of files **from the local system** to itself (no-op in some contexts). + In practice, this copies each file in ``sources`` to ``dest``. + + :param sources: List of file or directory paths on the local system. + :param dest: Destination path on the local system. + :param timeout: Timeout in seconds for each file copy; unused here (local copy). + :param as_root: If True, tries to escalate with sudo. Typically a no-op locally. + :raises TargetStableError: If the system is set to unrooted but as_root=True is used. + :raises OSError: If copying fails at any point. + """ self._copy_paths(sources, dest) - def pull(self, sources, dest, timeout=None, as_root=False): # pylint: disable=unused-argument + def pull(self, sources: List[str], dest: str, timeout: Optional[int] = None, + as_root: bool = False) -> None: # pylint: disable=unused-argument + """ + Transfer a list of files **from the local system** to the local system (similar to :meth:`push`). + + :param sources: list of paths on the local system. + :param dest: Destination directory or file path on local system. + :param timeout: Timeout in seconds; typically unused. + :param as_root: If True, attempts to use sudo for the copy, if not already root. + :raises TargetStableError: If the system is set to unrooted but as_root=True is used. + :raises OSError: If copying fails. 
+ """ self._copy_paths(sources, dest) # pylint: disable=unused-argument - def execute(self, command, timeout=None, check_exit_code=True, - as_root=False, strip_colors=True, will_succeed=False): + def execute(self, command: 'SubprocessCommand', timeout: Optional[int] = None, + check_exit_code: bool = True, as_root: Optional[bool] = False, + strip_colors: bool = True, will_succeed: bool = False) -> str: + """ + Execute a command locally (via :func:`subprocess.check_output`), returning + combined stdout+stderr output. Optionally escalates privileges with sudo. + + :param command: The command to execute (string or SubprocessCommand). + :param timeout: Time in seconds after which the command is forcibly terminated. + :param check_exit_code: If True, raise an error on nonzero exit codes. + :param as_root: If True, attempt sudo unless already root. Fails if ``unrooted=True``. + :param strip_colors: If True, attempt to remove ANSI color codes from output. + (Not used in this local example.) + :param will_succeed: If True, treat a failing command as a transient error + rather than stable. + :return: The combined stdout+stderr of the command. + :raises TargetTransientCalledProcessError: If the command fails but is considered transient. + :raises TargetStableCalledProcessError: If the command fails and is considered stable. + :raises TargetStableError: If run as root is requested but unrooted is True. + """ self.logger.debug(command) - use_sudo = as_root and not self.connected_as_root + use_sudo: Optional[bool] = as_root and not self.connected_as_root if use_sudo: if self.unrooted: raise TargetStableError('unrooted') - password = self._get_password() + password: str = self._get_password() # Empty prompt with -p '' to avoid adding a leading space to the # output. 
- command = "echo {} | sudo -k -p '' -S -- sh -c {}".format(quote(password), quote(command)) - ignore = None if check_exit_code else 'all' + command = "echo {} | sudo -k -p '' -S -- sh -c {}".format(quote(password), quote(cast(str, command))) + ignore: Optional[Union[int, List[int], Literal['all']]] = None if check_exit_code else 'all' try: stdout, stderr = check_output(command, shell=True, timeout=timeout, ignore=ignore) except subprocess.CalledProcessError as e: @@ -133,14 +256,28 @@ def execute(self, command, timeout=None, check_exit_code=True, return stdout + stderr - def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False): + def background(self, command: 'SubprocessCommand', stdout: int = subprocess.PIPE, + stderr: int = subprocess.PIPE, as_root: Optional[bool] = False) -> PopenBackgroundCommand: + """ + Launch a command on the local system in the background, returning + a handle to manage its execution via :class:`PopenBackgroundCommand`. + + :param command: The command or SubprocessCommand to run. + :param stdout: File handle or constant (e.g. subprocess.PIPE) for capturing stdout. + :param stderr: File handle or constant for capturing stderr. + :param as_root: If True, attempt to run with sudo unless already root. + :return: A background command object that can be polled, waited on, or killed. + :raises TargetStableError: If unrooted is True but as_root is requested. + + .. note:: This **will block the connection** until the command completes. + """ if as_root and not self.connected_as_root: if self.unrooted: raise TargetStableError('unrooted') - password = self._get_password() + password: str = self._get_password() # Empty prompt with -p '' to avoid adding a leading space to the # output. 
- command = "echo {} | sudo -k -p '' -S -- sh -c {}".format(quote(password), quote(command)) + command = "echo {} | sudo -k -p '' -S -- sh -c {}".format(quote(password), quote(cast(str, command))) # Make sure to get a new PGID so PopenBackgroundCommand() can kill # all sub processes that could be started without troubles. @@ -167,22 +304,48 @@ def make_init_kwargs(command): make_init_kwargs=make_init_kwargs, ) - def _close(self): + def _close(self) -> None: + """ + Close the connection to the device. The :class:`Connection` object should not + be used after this method is called. There is no way to reopen a previously + closed connection, a new connection object should be created instead. + """ pass - def cancel_running_command(self): + def cancel_running_command(self) -> None: + """ + Cancel a running command (previously started with :func:`background`) and free up the connection. + It is valid to call this if the command has already terminated (or if no + command was issued), in which case this is a no-op. + """ pass - def wait_for_device(self, timeout=30): + def wait_for_device(self, timeout: int = 30) -> None: + """ + Wait for the local system to be 'available'. In practice, this is always a no-op + since we are already local. + :param timeout: Ignored. + """ return - def reboot_bootloader(self, timeout=30): + def reboot_bootloader(self, timeout: int = 30) -> None: + """ + Attempt to reboot into a bootloader mode. Not implemented for local usage. + + :param timeout: Time in seconds to wait for the operation to complete. + :raises NotImplementedError: Always, as local usage does not support bootloader reboots. + """ raise NotImplementedError() - def _get_password(self): + def _get_password(self) -> str: + """ + Prompt for the user's sudo password if not already cached. + + :return: The password string, either from cache or via user input. 
+ """ if self.password: return self.password - password = getpass('sudo password:') + password: str = getpass('sudo password:') if self.keep_password: self.password = password return password diff --git a/devlib/instrument/__init__.py b/devlib/instrument/__init__.py index 6dca81cbe..6a0a38152 100644 --- a/devlib/instrument/__init__.py +++ b/devlib/instrument/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,31 +12,80 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import logging import collections - -from past.builtins import basestring - +from abc import abstractmethod from devlib.utils.csvutil import csvreader from devlib.utils.types import numeric from devlib.utils.types import identifier +from devlib.utils.misc import get_logger +from typing import (Dict, Optional, List, OrderedDict, + TYPE_CHECKING, Union, Callable, + Any, Tuple) +from collections.abc import Generator +if TYPE_CHECKING: + from devlib.target import Target # Channel modes describe what sort of measurement the instrument supports. # Values must be powers of 2 INSTANTANEOUS = 1 CONTINUOUS = 2 -MEASUREMENT_TYPES = {} # populated further down +MEASUREMENT_TYPES: Dict[str, 'MeasurementType'] = {} # populated further down class MeasurementType(object): - - def __init__(self, name, units, category=None, conversions=None): + """ + In order to make instruments easer to use, and to make it easier to swap them + out when necessary (e.g. change method of collecting power), a number of + standard measurement types are defined. This way, for example, power will + always be reported as "power" in Watts, and never as "pwr" in milliWatts. 
+ Currently defined measurement types are + + + +-------------+-------------+---------------+ + | Name | Units | Category | + +=============+=============+===============+ + | count | count | | + +-------------+-------------+---------------+ + | percent | percent | | + +-------------+-------------+---------------+ + | time | seconds | time | + +-------------+-------------+---------------+ + | time_us | microseconds| time | + +-------------+-------------+---------------+ + | time_ms | milliseconds| time | + +-------------+-------------+---------------+ + | time_ns | nanoseconds | time | + +-------------+-------------+---------------+ + | temperature | degrees | thermal | + +-------------+-------------+---------------+ + | power | watts | power/energy | + +-------------+-------------+---------------+ + | voltage | volts | power/energy | + +-------------+-------------+---------------+ + | current | amps | power/energy | + +-------------+-------------+---------------+ + | energy | joules | power/energy | + +-------------+-------------+---------------+ + | tx | bytes | data transfer | + +-------------+-------------+---------------+ + | rx | bytes | data transfer | + +-------------+-------------+---------------+ + | tx/rx | bytes | data transfer | + +-------------+-------------+---------------+ + | fps | fps | ui render | + +-------------+-------------+---------------+ + | frames | frames | ui render | + +-------------+-------------+---------------+ + + """ + def __init__(self, name: str, units: Optional[str], + category: Optional[str] = None, conversions: Optional[Dict[str, Callable]] = None): self.name = name self.units = units self.category = category - self.conversions = {} + self.conversions: Dict[str, Callable] = {} if conversions is not None: for key, value in conversions.items(): if not callable(value): @@ -44,24 +93,48 @@ def __init__(self, name, units, category=None, conversions=None): raise ValueError(msg.format(type(value), value)) self.conversions[key] = 
value - def convert(self, value, to): - if isinstance(to, basestring) and to in MEASUREMENT_TYPES: + def convert(self, value: str, to: Union[str, 'MeasurementType']) -> Union[str, 'MeasurementType']: + if isinstance(to, str) and to in MEASUREMENT_TYPES: to = MEASUREMENT_TYPES[to] if not isinstance(to, MeasurementType): - msg = 'Unexpected conversion target: "{}"' + msg: str = 'Unexpected conversion target: "{}"' raise ValueError(msg.format(to)) if to.name == self.name: return value - if not to.name in self.conversions: + if to.name not in self.conversions: msg = 'No conversion from {} to {} available' raise ValueError(msg.format(self.name, to.name)) return self.conversions[to.name](value) - # pylint: disable=undefined-variable - def __cmp__(self, other): + def __lt__(self, other): + if isinstance(other, MeasurementType): + return self.name < other.name + return self.name < other + + def __le__(self, other): + if isinstance(other, MeasurementType): + return self.name <= other.name + return self.name <= other + + def __eq__(self, other): + if isinstance(other, MeasurementType): + return self.name == other.name + return self.name == other + + def __ne__(self, other): + if isinstance(other, MeasurementType): + return self.name != other.name + return self.name != other + + def __gt__(self, other): + if isinstance(other, MeasurementType): + return self.name > other.name + return self.name > other + + def __ge__(self, other): if isinstance(other, MeasurementType): - other = other.name - return cmp(self.name, other) + return self.name >= other.name + return self.name >= other def __str__(self): return self.name @@ -79,7 +152,7 @@ def __repr__(self): # to particular insturments (e.g. a particular method of mearuing power), instruments # must, where possible, resport their measurments formatted as on of the standard types # defined here. 
-_measurement_types = [ +_measurement_types: List[MeasurementType] = [ # For whatever reason, the type of measurement could not be established. MeasurementType('unknown', None), @@ -95,33 +168,33 @@ def __repr__(self): # processors that expect all times time be at a particular scale can automatically # covert without being familar with individual instruments. MeasurementType('time', 'seconds', 'time', - conversions={ - 'time_us': lambda x: x * 1e6, - 'time_ms': lambda x: x * 1e3, - 'time_ns': lambda x: x * 1e9, - } - ), + conversions={ + 'time_us': lambda x: x * 1e6, + 'time_ms': lambda x: x * 1e3, + 'time_ns': lambda x: x * 1e9, + } + ), MeasurementType('time_us', 'microseconds', 'time', - conversions={ - 'time': lambda x: x / 1e6, - 'time_ms': lambda x: x / 1e3, - 'time_ns': lambda x: x * 1e3, - } - ), + conversions={ + 'time': lambda x: x / 1e6, + 'time_ms': lambda x: x / 1e3, + 'time_ns': lambda x: x * 1e3, + } + ), MeasurementType('time_ms', 'milliseconds', 'time', - conversions={ - 'time': lambda x: x / 1e3, - 'time_us': lambda x: x * 1e3, - 'time_ns': lambda x: x * 1e6, - } - ), + conversions={ + 'time': lambda x: x / 1e3, + 'time_us': lambda x: x * 1e3, + 'time_ns': lambda x: x * 1e6, + } + ), MeasurementType('time_ns', 'nanoseconds', 'time', - conversions={ - 'time': lambda x: x / 1e9, - 'time_ms': lambda x: x / 1e6, - 'time_us': lambda x: x / 1e3, - } - ), + conversions={ + 'time': lambda x: x / 1e9, + 'time_ms': lambda x: x / 1e6, + 'time_us': lambda x: x / 1e3, + } + ), # Measurements related to thermals. MeasurementType('temperature', 'degrees', 'thermal'), @@ -150,23 +223,52 @@ class Measurement(object): __slots__ = ['value', 'channel'] @property - def name(self): + def name(self) -> str: + """ + name of the measurement + """ return '{}_{}'.format(self.channel.site, self.channel.kind) @property - def units(self): + def units(self) -> Optional[str]: + """ + Units in which measurement will be reported. 
+ """ return self.channel.units - def __init__(self, value, channel): + def __init__(self, value: Union[int, float], channel: 'InstrumentChannel'): self.value = value self.channel = channel - # pylint: disable=undefined-variable - def __cmp__(self, other): + def __lt__(self, other): if hasattr(other, 'value'): - return cmp(self.value, other.value) - else: - return cmp(self.value, other) + return self.value < other.value + return self.value < other + + def __eq__(self, other): + if hasattr(other, 'value'): + return self.value == other.value + return self.value == other + + def __le__(self, other): + if hasattr(other, 'value'): + return self.value <= other.value + return self.value <= other + + def __ne__(self, other): + if hasattr(other, 'value'): + return self.value != other.value + return self.value != other + + def __gt__(self, other): + if hasattr(other, 'value'): + return self.value > other.value + return self.value > other + + def __ge__(self, other): + if hasattr(other, 'value'): + return self.value >= other.value + return self.value >= other def __str__(self): if self.units: @@ -179,44 +281,47 @@ def __str__(self): class MeasurementsCsv(object): - def __init__(self, path, channels=None, sample_rate_hz=None): + def __init__(self, path, channels: Optional[List['InstrumentChannel']] = None, + sample_rate_hz: Optional[float] = None): self.path = path self.channels = channels self.sample_rate_hz = sample_rate_hz if self.channels is None: self._load_channels() - headings = [chan.label for chan in self.channels] - self.data_tuple = collections.namedtuple('csv_entry', + headings = [chan.label for chan in self.channels] if self.channels else [] + + self.data_tuple = collections.namedtuple('csv_entry', # type:ignore map(identifier, headings)) - def measurements(self): + def measurements(self) -> List[List['Measurement']]: return list(self.iter_measurements()) - def iter_measurements(self): + def iter_measurements(self) -> Generator[List['Measurement'], None, None]: 
for row in self._iter_rows(): values = map(numeric, row) - yield [Measurement(v, c) for (v, c) in zip(values, self.channels)] + if self.channels: + yield [Measurement(v, c) for (v, c) in zip(values, self.channels)] - def values(self): + def values(self) -> List: return list(self.iter_values()) - def iter_values(self): + def iter_values(self) -> Generator[Tuple[Any], None, None]: for row in self._iter_rows(): values = list(map(numeric, row)) yield self.data_tuple(*values) - def _load_channels(self): - header = [] + def _load_channels(self) -> None: + header: List[str] = [] with csvreader(self.path) as reader: header = next(reader) self.channels = [] for entry in header: for mt in MEASUREMENT_TYPES: - suffix = '_{}'.format(mt) + suffix: str = '_{}'.format(mt) if entry.endswith(suffix): - site = entry[:-len(suffix)] - measure = mt + site: Optional[str] = entry[:-len(suffix)] + measure: str = mt break else: if entry in MEASUREMENT_TYPES: @@ -225,12 +330,12 @@ def _load_channels(self): else: site = entry measure = 'unknown' - - chan = InstrumentChannel(site, measure) - self.channels.append(chan) + if site: + chan = InstrumentChannel(site, measure) + self.channels.append(chan) # pylint: disable=stop-iteration-return - def _iter_rows(self): + def _iter_rows(self) -> Generator[List[str], None, None]: with csvreader(self.path) as reader: next(reader) # headings for row in reader: @@ -238,9 +343,51 @@ def _iter_rows(self): class InstrumentChannel(object): + """ + An :class:`InstrumentChannel` describes a single type of measurement that may + be collected by an :class:`~devlib.instrument.Instrument`. A channel is + primarily defined by a ``site`` and a ``measurement_type``. + + A ``site`` indicates where on the target a measurement is collected from + (e.g. a voltage rail or location of a sensor). + + A ``measurement_type`` is an instance of :class:`MeasurmentType` that + describes what sort of measurement this is (power, temperature, etc). 
Each + measurement type has a standard unit it is reported in, regardless of an + instrument used to collect it. + + A channel (i.e. site/measurement_type combination) is unique per instrument, + however there may be more than one channel associated with one site (e.g. for + both voltage and power). + + It should not be assumed that any site/measurement_type combination is valid. + The list of available channels can queried with + :func:`Instrument.list_channels()`. + + .. attribute:: InstrumentChannel.site + The name of the "site" from which the measurements are collected (e.g. voltage + rail, sensor, etc). + + """ @property - def label(self): + def label(self) -> str: + """ + Returns a label uniquely identifying the channel. + + This label is used to tag measurements and is constructed by + combining the channel's site and kind using the format: + '_'. + + If the site is not defined (i.e., None), only the kind is returned. + + Returns: + str: A string label for the channel. + + Example: + If site = "cluster0" and kind = "power", the label will be "cluster0_power". + If site = None and kind = "temperature", the label will be "temperature". + """ if self.site is not None: return '{}_{}'.format(self.site, self.kind) return self.kind @@ -248,14 +395,22 @@ def label(self): name = label @property - def kind(self): + def kind(self) -> str: + """ + A string indicating the type of measurement that will be collected. This is + the ``name`` of the :class:`MeasurmentType` associated with this channel. + """ return self.measurement_type.name @property - def units(self): + def units(self) -> Optional[str]: + """ + Units in which measurement will be reported. this is determined by the + underlying :class:`MeasurmentType`. 
+ """ return self.measurement_type.units - def __init__(self, site, measurement_type, **attrs): + def __init__(self, site: str, measurement_type: Union[str, MeasurementType], **attrs): self.site = site if isinstance(measurement_type, MeasurementType): self.measurement_type = measurement_type @@ -277,39 +432,117 @@ def __str__(self): class Instrument(object): + """ + The ``Instrument`` API provide a consistent way of collecting measurements from + a target. Measurements are collected via an instance of a class derived from + :class:`~devlib.instrument.Instrument`. An ``Instrument`` allows collection of + measurement from one or more channels. An ``Instrument`` may support + ``INSTANTANEOUS`` or ``CONTINUOUS`` collection, or both. + + .. attribute:: Instrument.mode - mode = 0 + A bit mask that indicates collection modes that are supported by this + instrument. Possible values are: - def __init__(self, target): + :INSTANTANEOUS: The instrument supports taking a single sample via + ``take_measurement()``. + :CONTINUOUS: The instrument supports collecting measurements over a + period of time via ``start()``, ``stop()``, ``get_data()``, + and (optionally) ``get_raw`` methods. + + .. note:: It's possible for one instrument to support more than a single + mode. + + .. attribute:: Instrument.active_channels + + Channels that have been activated via ``reset()``. Measurements will only be + collected for these channels. + .. attribute:: Instrument.sample_rate_hz + + Sample rate of the instrument in Hz. Assumed to be the same for all channels. + + .. note:: This attribute is only provided by + :class:`~devlib.instrument.Instrument` s that + support ``CONTINUOUS`` measurement. 
+ """ + mode: int = 0 + + def __init__(self, target: 'Target'): self.target = target - self.logger = logging.getLogger(self.__class__.__name__) - self.channels = collections.OrderedDict() - self.active_channels = [] - self.sample_rate_hz = None + self.logger = get_logger(self.__class__.__name__) + self.channels: OrderedDict[str, InstrumentChannel] = collections.OrderedDict() + self.active_channels: List[InstrumentChannel] = [] + self.sample_rate_hz: Optional[float] = None # channel management - def list_channels(self): + def list_channels(self) -> List[InstrumentChannel]: + """ + Returns a list of :class:`InstrumentChannel` instances that describe what + this instrument can measure on the current target. A channel is a combination + of a ``kind`` of measurement (power, temperature, etc) and a ``site`` that + indicates where on the target the measurement will be collected from. + """ return list(self.channels.values()) - def get_channels(self, measure): - if hasattr(measure, 'name'): - measure = measure.name + def get_channels(self, measure: Union[str, MeasurementType]): + """ + Returns channels for a particular ``measure`` type. A ``measure`` can be + either a string (e.g. ``"power"``) or a :class:`MeasurmentType` instance. + """ + if isinstance(measure, MeasurementType): + if hasattr(measure, 'name'): + measure = measure.name return [c for c in self.list_channels() if c.kind == measure] - def add_channel(self, site, measure, **attrs): + def add_channel(self, site: str, measure: Union[str, MeasurementType], **attrs) -> None: + """ + add channel to channels dict + """ chan = InstrumentChannel(site, measure, **attrs) self.channels[chan.label] = chan # initialization and teardown - def setup(self, *args, **kwargs): + def setup(self, *args, **kwargs) -> None: + """ + This will set up the instrument on the target. Parameters this method takes + are particular to subclasses (see documentation for specific instruments + below). 
What actions are performed by this method are also + instrument-specific. Usually these will be things like installing + executables, starting services, deploying assets, etc. Typically, this method + needs to be invoked at most once per reboot of the target (unless + ``teardown()`` has been called), but see documentation for the instrument + you're interested in. + """ pass - def teardown(self): + def teardown(self) -> None: + """ + Performs any required clean up of the instrument. This usually includes + removing temporary and raw files (if ``keep_raw`` is set to ``False`` on relevant + instruments), stopping services etc. + """ pass - def reset(self, sites=None, kinds=None, channels=None): + def reset(self, sites: Optional[List[str]] = None, + kinds: Optional[List[str]] = None, + channels: Optional[OrderedDict[str, InstrumentChannel]] = None) -> None: + """ + This is used to configure an instrument for collection. This must be invoked + before ``start()`` is called to begin collection. This methods sets the + ``active_channels`` attribute of the ``Instrument``. + + If ``channels`` is provided, it is a list of names of channels to enable and + ``sites`` and ``kinds`` must both be ``None``. + + Otherwise, if one of ``sites`` or ``kinds`` is provided, all channels + matching the given sites or kinds are enabled. If both are provided then all + channels of the given kinds at the given sites are enabled. + + If none of ``sites``, ``kinds`` or ``channels`` are provided then all + available channels are enabled. 
+ """ if channels is not None: if sites is not None or kinds is not None: raise ValueError('sites and kinds should not be set if channels is set') @@ -317,36 +550,93 @@ def reset(self, sites=None, kinds=None, channels=None): try: self.active_channels = [self.channels[ch] for ch in channels] except KeyError as e: - msg = 'Unexpected channel "{}"; must be in {}' + msg: str = 'Unexpected channel "{}"; must be in {}' raise ValueError(msg.format(e, self.channels.keys())) elif sites is None and kinds is None: self.active_channels = sorted(self.channels.values(), key=lambda x: x.label) else: - if isinstance(sites, basestring): + if isinstance(sites, str): sites = [sites] - if isinstance(kinds, basestring): + if isinstance(kinds, str): kinds = [kinds] wanted = lambda ch: ((kinds is None or ch.kind in kinds) and - (sites is None or ch.site in sites)) + (sites is None or ch.site in sites)) self.active_channels = list(filter(wanted, self.channels.values())) # instantaneous - - def take_measurement(self): + @abstractmethod + def take_measurement(self) -> List[Measurement]: + """ + Take a single measurement from ``active_channels``. Returns a list of + :class:`Measurement` objects (one for each active channel). + + .. note:: This method is only implemented by + :class:`~devlib.instrument.Instrument's that + support ``INSTANTANEOUS`` measurement. + """ pass # continuous - def start(self): + def start(self) -> None: + """ + Starts collecting measurements from ``active_channels``. + + .. note:: This method is only implemented by + :class:`~devlib.instrument.Instrument` s that + support ``CONTINUOUS`` measurement. + """ pass - def stop(self): + def stop(self) -> None: + """ + Stops collecting measurements from ``active_channels``. Must be called after + :func:`start()`. + + .. note:: This method is only implemented by + :class:`~devlib.instrument.Instrument` s that + support ``CONTINUOUS`` measurement. 
+ """ pass - # pylint: disable=no-self-use - def get_data(self, outfile): + @abstractmethod + def get_data(self, outfile: str) -> MeasurementsCsv: + """ + Write collected data into ``outfile``. Must be called after :func:`stop()`. + Data will be written in CSV format with a column for each channel and a row + for each sample. Column heading will be channel, labels in the form + ``_`` (see :class:`InstrumentChannel`). The order of the columns + will be the same as the order of channels in ``Instrument.active_channels``. + + If reporting timestamps, one channel must have a ``site`` named + ``"timestamp"`` and a ``kind`` of a :class:`MeasurmentType` of an appropriate + time unit which will be used, if appropriate, during any post processing. + + .. note:: Currently supported time units are seconds, milliseconds and + microseconds, other units can also be used if an appropriate + conversion is provided. + + This returns a :class:`MeasurementCsv` instance associated with the outfile + that can be used to stream :class:`Measurement` s lists (similar to what is + returned by ``take_measurement()``. + + .. note:: This method is only implemented by + :class:`~devlib.instrument.Instrument` s that + support ``CONTINUOUS`` measurement. + """ pass - def get_raw(self): + def get_raw(self) -> List[str]: + """ + Returns a list of paths to files containing raw output from the underlying + source(s) that is used to produce the data CSV. If no raw output is + generated or saved, an empty list will be returned. The format of the + contents of the raw files is entirely source-dependent. + + .. note:: This method is not guaranteed to return valid filepaths after the + :meth:`teardown` method has been invoked as the raw files may have + been deleted. Please ensure that copies are created manually + prior to calling :meth:`teardown` if the files are to be retained. 
+ """ return [] diff --git a/devlib/instrument/acmecape.py b/devlib/instrument/acmecape.py index cfbcbe071..1d2f13197 100644 --- a/devlib/instrument/acmecape.py +++ b/devlib/instrument/acmecape.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # limitations under the License. # -#pylint: disable=attribute-defined-outside-init +# pylint: disable=attribute-defined-outside-init import os import sys import time @@ -34,6 +34,7 @@ ${iio_capture} -n ${host} -b ${buffer_size} -c -f ${outfile} ${iio_device} """) + def _read_nonblock(pipe, size=1024): fd = pipe.fileno() flags = fcntl(fd, F_GETFL) @@ -93,7 +94,7 @@ def reset(self, sites=None, kinds=None, channels=None): iio_device=self.iio_device, outfile=self.raw_data_file ) - params = {k: quote(v) for k, v in params.items()} + params = {k: quote(v or '') for k, v in params.items()} self.command = IIOCAP_CMD_TEMPLATE.substitute(**params) self.logger.debug('ACME cape command: {}'.format(self.command)) @@ -115,7 +116,7 @@ def stop(self): if self.process.poll() is None: msg = 'Could not terminate iio-capture:\n{}' raise HostError(msg.format(output)) - if self.process.returncode != 15: # iio-capture exits with 15 when killed + if self.process.returncode != 15: # iio-capture exits with 15 when killed output += self.process.stdout.read().decode(sys.stdout.encoding or 'utf-8', 'replace') self.logger.info('ACME instrument encountered an error, ' 'you may want to try rebooting the ACME device:\n' diff --git a/devlib/instrument/arm_energy_probe.py b/devlib/instrument/arm_energy_probe.py index 80ef643da..1e697ec28 100644 --- a/devlib/instrument/arm_energy_probe.py +++ b/devlib/instrument/arm_energy_probe.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the 
"License"); # you may not use this file except in compliance with the License. @@ -45,6 +45,7 @@ from devlib.utils.parse_aep import AepParser + class ArmEnergyProbeInstrument(Instrument): """ Collects power traces using the ARM Energy Probe. @@ -68,23 +69,23 @@ class ArmEnergyProbeInstrument(Instrument): mode = CONTINUOUS - MAX_CHANNELS = 12 # 4 Arm Energy Probes + MAX_CHANNELS = 12 # 4 Arm Energy Probes def __init__(self, target, config_file='./config-aep', keep_raw=False): super(ArmEnergyProbeInstrument, self).__init__(target) self.arm_probe = which('arm-probe') if self.arm_probe is None: raise HostError('arm-probe must be installed on the host') - #todo detect is config file exist + # todo detect is config file exist self.attributes = ['power', 'voltage', 'current'] self.sample_rate_hz = 10000 self.config_file = config_file self.keep_raw = keep_raw self.parser = AepParser() - #TODO make it generic + # TODO make it generic topo = self.parser.topology_from_config(self.config_file) - for item in topo: + for item in topo or []: if item == 'time': self.add_channel('timestamp', 'time') else: @@ -103,9 +104,9 @@ def reset(self, sites=None, kinds=None, channels=None): def start(self): self.logger.debug(self.command) self.armprobe = subprocess.Popen(self.command, - stderr=self.output_fd_error, - preexec_fn=os.setpgrp, - shell=True) + stderr=self.output_fd_error, + preexec_fn=os.setpgrp, + shell=True) def stop(self): self.logger.debug("kill running arm-probe") @@ -132,7 +133,7 @@ def get_data(self, outfile): # pylint: disable=R0914 if len(row) < len(active_channels): continue # all data are in micro (seconds/watt) - new = [float(row[i])/1000000 for i in active_indexes] + new = [float(row[i]) / 1000000 for i in active_indexes] writer.writerow(new) self.output_fd_error.close() diff --git a/devlib/instrument/daq.py b/devlib/instrument/daq.py index 97c638fd8..fa32eca47 100644 --- a/devlib/instrument/daq.py +++ b/devlib/instrument/daq.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM 
Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,11 +20,10 @@ from itertools import chain, zip_longest from devlib.host import PACKAGE_BIN_DIRECTORY -from devlib.instrument import Instrument, MeasurementsCsv, CONTINUOUS +from devlib.instrument import Instrument, MeasurementsCsv, CONTINUOUS, InstrumentChannel from devlib.exception import HostError from devlib.utils.csvutil import csvwriter, create_reader from devlib.utils.misc import unique - try: from daqpower.client import DaqClient from daqpower.config import DeviceConfiguration @@ -32,31 +31,36 @@ DaqClient = None DeviceConfiguration = None import_error_mesg = e.args[0] if e.args else str(e) +from typing import (TYPE_CHECKING, List, Union, Optional, Tuple, + cast, Dict, TextIO, Any, OrderedDict) +if TYPE_CHECKING: + from devlib.target import Target + from daqpower.server import DaqServer class DaqInstrument(Instrument): mode = CONTINUOUS - def __init__(self, target, resistor_values, # pylint: disable=R0914 - labels=None, - host='localhost', - port=45677, - device_id='Dev1', - v_range=2.5, - dv_range=0.2, - sample_rate_hz=10000, - channel_map=(0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23), - keep_raw=False, - time_as_clock_boottime=True + def __init__(self, target: 'Target', resistor_values: List[Union[int, str]], # pylint: disable=R0914 + labels: Optional[List[str]] = None, + host: str = 'localhost', + port: int = 45677, + device_id: str = 'Dev1', + v_range: float = 2.5, + dv_range: float = 0.2, + sample_rate_hz: int = 10000, + channel_map: Tuple = (0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23), + keep_raw: bool = False, + time_as_clock_boottime: bool = True ): # pylint: disable=no-member super(DaqInstrument, self).__init__(target) self.keep_raw = keep_raw - self._need_reset = True - self._raw_files = [] - self.tempdir = None - 
self.target_boottime_clock_at_start = 0.0 + self._need_reset: bool = True + self._raw_files: List[str] = [] + self.tempdir: Optional[str] = None + self.target_boottime_clock_at_start: float = 0.0 if DaqClient is None: raise HostError('Could not import "daqpower": {}'.format(import_error_mesg)) if labels is None: @@ -65,20 +69,20 @@ def __init__(self, target, resistor_values, # pylint: disable=R0914 raise ValueError('"labels" and "resistor_values" must be of the same length') self.daq_client = DaqClient(host, port) try: - devices = self.daq_client.list_devices() + devices: List[str] = cast('DaqServer', self.daq_client).list_devices() if device_id not in devices: msg = 'Device "{}" is not found on the DAQ server. Available devices are: "{}"' raise ValueError(msg.format(device_id, ', '.join(devices))) except Exception as e: raise HostError('Problem querying DAQ server: {}'.format(e)) - - self.device_config = DeviceConfiguration(device_id=device_id, - v_range=v_range, - dv_range=dv_range, - sampling_rate=sample_rate_hz, - resistor_values=resistor_values, - channel_map=channel_map, - labels=labels) + if DeviceConfiguration: + self.device_config = DeviceConfiguration(device_id=device_id, + v_range=v_range, + dv_range=dv_range, + sampling_rate=sample_rate_hz, + resistor_values=resistor_values, + channel_map=channel_map, + labels=labels) self.sample_rate_hz = sample_rate_hz self.time_as_clock_boottime = time_as_clock_boottime @@ -88,62 +92,67 @@ def __init__(self, target, resistor_values, # pylint: disable=R0914 self.add_channel(label, kind) if time_as_clock_boottime: - host_path = os.path.join(PACKAGE_BIN_DIRECTORY, self.target.abi, - 'get_clock_boottime') + host_path: str = os.path.join(PACKAGE_BIN_DIRECTORY, self.target.abi or '', + 'get_clock_boottime') self.clock_boottime_cmd = self.target.install_if_needed(host_path, search_system_binaries=False) - def calculate_boottime_offset(self): - time_before = time.time() - out = self.target.execute(self.clock_boottime_cmd) - 
time_after = time.time() + def calculate_boottime_offset(self) -> float: + """ + calculate boot time offset + """ + time_before: float = time.time() + out: str = self.target.execute(self.clock_boottime_cmd) + time_after: float = time.time() remote_clock_boottime = float(out) - propagation_delay = (time_after - time_before) / 2 - boottime_at_end = remote_clock_boottime + propagation_delay + propagation_delay: float = (time_after - time_before) / 2 + boottime_at_end: float = remote_clock_boottime + propagation_delay return time_after - boottime_at_end - def reset(self, sites=None, kinds=None, channels=None): + def reset(self, sites: Optional[List[str]] = None, + kinds: Optional[List[str]] = None, + channels: Optional[OrderedDict[str, InstrumentChannel]] = None) -> None: super(DaqInstrument, self).reset(sites, kinds, channels) - self.daq_client.close() - self.daq_client.configure(self.device_config) + cast('DaqServer', self.daq_client).close() + cast('DaqServer', self.daq_client).configure(self.device_config) self._need_reset = False self._raw_files = [] - def start(self): + def start(self) -> None: if self._need_reset: # Preserve channel order - self.reset(channels=self.channels.keys()) + self.reset(channels=cast(OrderedDict[str, InstrumentChannel], self.channels.keys())) if self.time_as_clock_boottime: target_boottime_offset = self.calculate_boottime_offset() time_start = time.time() - self.daq_client.start() + cast('DaqServer', self.daq_client).start() if self.time_as_clock_boottime: - time_end = time.time() + time_end: float = time.time() self.target_boottime_clock_at_start = (time_start + time_end) / 2 - target_boottime_offset - def stop(self): - self.daq_client.stop() + def stop(self) -> None: + cast('DaqServer', self.daq_client).stop() self._need_reset = True - def get_data(self, outfile): # pylint: disable=R0914 + def get_data(self, outfile: str) -> MeasurementsCsv: # pylint: disable=R0914 self.tempdir = tempfile.mkdtemp(prefix='daq-raw-') 
self.daq_client.get_data(self.tempdir) - raw_file_map = {} + raw_file_map: Dict[str, str] = {} for entry in os.listdir(self.tempdir): - site = os.path.splitext(entry)[0] - path = os.path.join(self.tempdir, entry) + site: str = os.path.splitext(entry)[0] + path: str = os.path.join(self.tempdir, entry) raw_file_map[site] = path self._raw_files.append(path) - active_sites = unique([c.site for c in self.active_channels]) - file_handles = [] + active_sites: List[str] = unique([c.site for c in self.active_channels]) + file_handles: List[TextIO] = [] try: - site_readers = {} + site_readers: Dict[str, Any] = {} for site in active_sites: try: site_file = raw_file_map[site] @@ -152,11 +161,11 @@ def get_data(self, outfile): # pylint: disable=R0914 file_handles.append(fh) except KeyError: if not site.startswith("Time"): - message = 'Could not get DAQ trace for {}; Obtained traces are in {}' + message: str = 'Could not get DAQ trace for {}; Obtained traces are in {}' raise HostError(message.format(site, self.tempdir)) # The first row is the headers - channel_order = ['Time_time'] + channel_order: List[str] = ['Time_time'] for site, reader in site_readers.items(): channel_order.extend(['{}_{}'.format(site, kind) for kind in next(reader)]) @@ -167,15 +176,15 @@ def _read_rows(): raw_row = list(chain.from_iterable(raw_row)) raw_row.insert(0, _read_rows.row_time_s) yield raw_row - _read_rows.row_time_s += 1.0 / self.sample_rate_hz + _read_rows.row_time_s += 1.0 / cast(float, self.sample_rate_hz) - _read_rows.row_time_s = self.target_boottime_clock_at_start + _read_rows.row_time_s = self.target_boottime_clock_at_start # type:ignore with csvwriter(outfile) as writer: - field_names = [c.label for c in self.active_channels] + field_names: List[str] = [c.label for c in self.active_channels] writer.writerow(field_names) for raw_row in _read_rows(): - row = [raw_row[channel_order.index(f)] for f in field_names] + row: List[str] = [raw_row[channel_order.index(f)] for f in field_names] 
writer.writerow(row) return MeasurementsCsv(outfile, self.active_channels, self.sample_rate_hz) @@ -183,11 +192,11 @@ def _read_rows(): for fh in file_handles: fh.close() - def get_raw(self): + def get_raw(self) -> List[str]: return self._raw_files - def teardown(self): - self.daq_client.close() + def teardown(self) -> None: + cast('DaqServer', self.daq_client).close() if not self.keep_raw: if self.tempdir and os.path.isdir(self.tempdir): shutil.rmtree(self.tempdir) diff --git a/devlib/instrument/frames.py b/devlib/instrument/frames.py index 402c48194..a5cd2c8ce 100644 --- a/devlib/instrument/frames.py +++ b/devlib/instrument/frames.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,74 +16,88 @@ import os from devlib.instrument import (Instrument, CONTINUOUS, - MeasurementsCsv, MeasurementType) + MeasurementsCsv, MeasurementType, + InstrumentChannel) from devlib.utils.rendering import (GfxinfoFrameCollector, SurfaceFlingerFrameCollector, SurfaceFlingerFrame, - read_gfxinfo_columns) + read_gfxinfo_columns, + FrameCollector) +from typing import (TYPE_CHECKING, Optional, Type, + OrderedDict, Any, List) +if TYPE_CHECKING: + from devlib.target import Target class FramesInstrument(Instrument): mode = CONTINUOUS - collector_cls = None + collector_cls: Optional[Type[FrameCollector]] = None - def __init__(self, target, collector_target, period=2, keep_raw=True): + def __init__(self, target: 'Target', collector_target: Any, + period: int = 2, keep_raw: bool = True): super(FramesInstrument, self).__init__(target) self.collector_target = collector_target self.period = period self.keep_raw = keep_raw - self.sample_rate_hz = 1 / self.period - self.collector = None - self.header = None - self._need_reset = True - self._raw_file = None + self.sample_rate_hz: float = 1 / self.period + self.collector: 
Optional[FrameCollector] = None + self.header: Optional[List[str]] = None + self._need_reset: bool = True + self._raw_file: Optional[str] = None self._init_channels() - def reset(self, sites=None, kinds=None, channels=None): + def reset(self, sites: Optional[List[str]] = None, + kinds: Optional[List[str]] = None, + channels: Optional[OrderedDict[str, InstrumentChannel]] = None) -> None: super(FramesInstrument, self).reset(sites, kinds, channels) - # pylint: disable=not-callable - self.collector = self.collector_cls(self.target, self.period, - self.collector_target, self.header) + if self.collector_cls: + # pylint: disable=not-callable + self.collector = self.collector_cls(self.target, self.period, + self.collector_target, self.header) # type: ignore self._need_reset = False self._raw_file = None - def start(self): + def start(self) -> None: if self._need_reset: self.reset() - self.collector.start() + if self.collector: + self.collector.start() - def stop(self): - self.collector.stop() + def stop(self) -> None: + if self.collector: + self.collector.stop() self._need_reset = True - def get_data(self, outfile): + def get_data(self, outfile: str) -> MeasurementsCsv: if self.keep_raw: self._raw_file = outfile + '.raw' - self.collector.process_frames(self._raw_file) - active_sites = [chan.label for chan in self.active_channels] - self.collector.write_frames(outfile, columns=active_sites) + if self.collector: + self.collector.process_frames(self._raw_file) + active_sites: List[str] = [chan.label for chan in self.active_channels] + if self.collector: + self.collector.write_frames(outfile, columns=active_sites) return MeasurementsCsv(outfile, self.active_channels, self.sample_rate_hz) - def get_raw(self): + def get_raw(self) -> List[str]: return [self._raw_file] if self._raw_file else [] - def _init_channels(self): + def _init_channels(self) -> None: raise NotImplementedError() - def teardown(self): + def teardown(self) -> None: if not self.keep_raw: - if 
os.path.isfile(self._raw_file): - os.remove(self._raw_file) + if os.path.isfile(self._raw_file or ''): + os.remove(self._raw_file or '') class GfxInfoFramesInstrument(FramesInstrument): - mode = CONTINUOUS + mode: int = CONTINUOUS collector_cls = GfxinfoFrameCollector - def _init_channels(self): - columns = read_gfxinfo_columns(self.target) + def _init_channels(self) -> None: + columns: List[str] = read_gfxinfo_columns(self.target) for entry in columns: if entry == 'Flags': self.add_channel('Flags', MeasurementType('flags', 'flags')) @@ -94,10 +108,10 @@ def _init_channels(self): class SurfaceFlingerFramesInstrument(FramesInstrument): - mode = CONTINUOUS + mode: int = CONTINUOUS collector_cls = SurfaceFlingerFrameCollector - def _init_channels(self): + def _init_channels(self) -> None: for field in SurfaceFlingerFrame._fields: # remove the "_time" from filed names to avoid duplication self.add_channel(field[:-5], 'time_us') diff --git a/devlib/instrument/hwmon.py b/devlib/instrument/hwmon.py index 7c1cb7d1a..16795eaf2 100644 --- a/devlib/instrument/hwmon.py +++ b/devlib/instrument/hwmon.py @@ -1,4 +1,4 @@ -# Copyright 2015-2017 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,15 +16,20 @@ from devlib.instrument import Instrument, Measurement, INSTANTANEOUS from devlib.exception import TargetStableError +from typing import (Dict, Tuple, Callable, Union, TYPE_CHECKING, + cast, List) +from devlib.module.hwmon import HwmonModule, HwmonSensor +if TYPE_CHECKING: + from devlib.target import Target class HwmonInstrument(Instrument): - name = 'hwmon' - mode = INSTANTANEOUS + name: str = 'hwmon' + mode: int = INSTANTANEOUS # sensor kind --> (meaure, standard unit conversion) - measure_map = { + measure_map: Dict[str, Tuple[str, Callable[[Union[int, float]], float]]] = { 'temp': ('temperature', lambda x: x / 1000), 'in': ('voltage', lambda x: x / 1000), 'curr': ('current', lambda x: x / 1000), @@ -32,16 +37,18 @@ class HwmonInstrument(Instrument): 'energy': ('energy', lambda x: x / 1000000), } - def __init__(self, target): + def __init__(self, target: 'Target'): if not hasattr(target, 'hwmon'): raise TargetStableError('Target does not support HWMON') super(HwmonInstrument, self).__init__(target) self.logger.debug('Discovering available HWMON sensors...') - for ts in self.target.hwmon.sensors: + for ts in cast(HwmonModule, self.target.hwmon).sensors: try: ts.get_file('input') - measure = self.measure_map.get(ts.kind)[0] + measure_map = self.measure_map.get(ts.kind) + if measure_map: + measure: str = measure_map[0] if measure: self.logger.debug('\tAdding sensor {}'.format(ts.name)) self.add_channel(_guess_site(ts), measure, sensor=ts) @@ -52,16 +59,16 @@ def __init__(self, target): self.logger.debug(message.format(ts.name)) continue - def take_measurement(self): - result = [] + def take_measurement(self) -> List[Measurement]: + result: List[Measurement] = [] for chan in self.active_channels: - convert = self.measure_map[chan.sensor.kind][1] - value = convert(chan.sensor.get('input')) + convert = self.measure_map[chan.sensor.kind][1] # type: ignore + value = convert(chan.sensor.get('input')) # type: ignore result.append(Measurement(value, 
chan)) return result -def _guess_site(sensor): +def _guess_site(sensor: HwmonSensor): """ HWMON does not specify a standard for labeling its sensors, or for device/item split (the implication is that each hwmon device a separate chip @@ -74,7 +81,7 @@ def _guess_site(sensor): # If no label has been specified for the sensor (in which case, it # defaults to the sensor's name), assume that the "site" of the sensor # is identified by the HWMON device - text = sensor.device.name + text: str = sensor.device.name else: # If a label has been specified, assume multiple sensors controlled by # the same device and the label identifies the site. diff --git a/devlib/module/__init__.py b/devlib/module/__init__.py index c450ba17e..473cb2efa 100644 --- a/devlib/module/__init__.py +++ b/devlib/module/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2014-2018 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,16 +12,50 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import logging from inspect import isclass from devlib.exception import TargetStableError from devlib.utils.types import identifier -from devlib.utils.misc import walk_modules +from devlib.utils.misc import walk_modules, get_logger +from typing import (Optional, Dict, Union, Type, + TYPE_CHECKING, Any) +if TYPE_CHECKING: + from devlib.target import Target +_module_registry: Dict[str, Type['Module']] = {} -_module_registry = {} -def register_module(mod): +def register_module(mod: Type['Module']) -> None: + """ + Modules are specified on :class:`~devlib.target.Target` or + :class:`~devlib.platform.Platform` creation by name. In order to find the class + associated with the name, the module needs to be registered with ``devlib``. + This is accomplished by passing the module class into :func:`register_module` + method once it is defined. 
+ + .. note:: If you're wiring a module to be included as part of ``devlib`` code + base, you can place the file with the module class under + ``devlib/modules/`` in the source and it will be automatically + enumerated. There is no need to explicitly register it in that case. + + The code snippet below illustrates an implementation of a hard reset function + for an "Acme" device. + + .. code:: python + + import os + from devlib import HardResetModule, register_module + + + class AcmeHardReset(HardResetModule): + + name = 'acme_hard_reset' + + def __call__(self): + # Assuming Acme board comes with a "reset-acme-board" utility + os.system('reset-acme-board {}'.format(self.target.name)) + + register_module(AcmeHardReset) + """ if not issubclass(mod, Module): raise ValueError('A module must subclass devlib.Module') @@ -39,34 +73,87 @@ def register_module(mod): class Module: - - name = None - kind = None - # This is the stage at which the module will be installed. Current valid - # stages are: - # 'early' -- installed when the Target is first created. This should be - # used for modules that do not rely on the main connection - # being established (usually because the commumnitcate with the - # target through some sorto of secondary connection, e.g. via - # serial). - # 'connected' -- installed when a connection to to the target has been - # established. This is the default. - # 'setup' -- installed after initial setup of the device has been performed. - # This allows the module to utilize assets deployed during the - # setup stage for example 'Busybox'. - stage = 'connected' + """ + Modules add additional functionality to the core :class:`~devlib.target.Target` + interface. Usually, it is support for specific subsystems on the target. Modules + are instantiated as attributes of the :class:`~devlib.target.Target` instance. 
+ + Modules implement discrete, optional pieces of functionality ("optional" in the + sense that the functionality may or may not be present on the target device, or + that it may or may not be necessary for a particular application). + + Every module (ultimately) derives from :class:`devlib.module.Module` class. A + module must define the following class attributes: + + :name: A unique name for the module. This cannot clash with any of the existing + names and must be a valid Python identifier, but is otherwise free-form. + :kind: This identifies the type of functionality a module implements, which in + turn determines the interface implemented by the module (all modules of + the same kind must expose a consistent interface). This must be a valid + Python identifier, but is otherwise free-form, though, where possible, + one should try to stick to an already-defined kind/interface, lest we end + up with a bunch of modules implementing similar functionality but + exposing slightly different interfaces. + + .. note:: It is possible to omit ``kind`` when defining a module, in + which case the module's ``name`` will be treated as its + ``kind`` as well. + + :stage: This defines when the module will be installed into a + :class:`~devlib.target.Target`. Currently, the following values are + allowed: + + :connected: The module is installed after a connection to the target has + been established. This is the default. + :early: The module will be installed when a + :class:`~devlib.target.Target` is first created. This should be + used for modules that do not rely on a live connection to the + target. + :setup: The module will be installed after initial setup of the device + has been performed. This allows the module to utilize assets + deployed during the setup stage for example 'Busybox'. 
+ + Additionally, a module must implement a static (or class) method :func:`probe`: + """ + name: Optional[str] = None + kind: Optional[str] = None + attr_name: Optional[str] = None + stage: str = 'connected' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: + """ + This method takes a :class:`~devlib.target.Target` instance and returns + ``True`` if this module is supported by that target, or ``False`` otherwise. + + .. note:: If the module ``stage`` is ``"early"``, this method cannot assume + that a connection has been established (i.e. it can only access + attributes of the Target that do not rely on a connection). + """ raise NotImplementedError() @classmethod - def install(cls, target, **params): - attr_name = cls.attr_name - installed = target._installed_modules + def install(cls, target: 'Target', **params: Type['Module']): + """ + The default installation method will create an instance of a module (the + :class:`~devlib.target.Target` instance being the sole argument) and assign it + to the target instance attribute named after the module's ``kind`` (or + ``name`` if ``kind`` is ``None``). + + It is possible to change the installation procedure for a module by overriding + the default :func:`install` method. The method must have the following + signature: + + .. method:: Module.install(cls, target, **kwargs) + + Install the module into the target instance. 
+ """ + attr_name: Optional[str] = cls.attr_name + installed: Dict[str, 'Module'] = target._installed_modules try: - mod = installed[attr_name] + if attr_name: + mod: 'Module' = installed[attr_name] except KeyError: mod = cls(target, **params) mod.logger.debug(f'Installing module {cls.name}') @@ -79,8 +166,8 @@ def install(cls, target, **params): ): if name is not None: installed[name] = mod - - target._modules[cls.name] = params + if cls.name: + target._modules[cls.name] = params return mod else: raise TargetStableError(f'Module "{cls.name}" is not supported by the target') @@ -89,15 +176,14 @@ def install(cls, target, **params): f'Attempting to install module "{cls.name}" but a module is already installed as attribute "{attr_name}": {mod}' ) - def __init__(self, target): + def __init__(self, target: 'Target'): self.target = target - self.logger = logging.getLogger(self.name) - + self.logger = get_logger(self.name) - def __init_subclass__(cls, *args, **kwargs): + def __init_subclass__(cls, *args, **kwargs) -> None: super().__init_subclass__(*args, **kwargs) - attr_name = cls.kind or cls.name + attr_name: Optional[str] = cls.kind or cls.name cls.attr_name = identifier(attr_name) if attr_name else None if cls.name is not None: @@ -105,21 +191,60 @@ def __init_subclass__(cls, *args, **kwargs): class HardRestModule(Module): + """ + .. attribute:: HardResetModule.kind - kind = 'hard_reset' + "hard_reset" + """ + + kind: str = 'hard_reset' def __call__(self): + """ + .. method:: HardResetModule.__call__() + + Must be implemented by derived classes. + + Implements hard reset for a target devices. The equivalent of physically + power cycling the device. This may be used by client code in situations + where the target becomes unresponsive and/or a regular reboot is not + possible. + """ raise NotImplementedError() class BootModule(Module): + """ + .. attribute:: BootModule.kind - kind = 'boot' + "boot" + """ + + kind: str = 'boot' def __call__(self): + """ + .. 
method:: BootModule.__call__() + + Must be implemented by derived classes. + + Implements a boot procedure. This takes the device from (hard or soft) + reset to a booted state where the device is ready to accept connections. For + a lot of commercial devices the process is entirely automatic, however some + devices (e.g. development boards), my require additional steps, such as + interactions with the bootloader, in order to boot into the OS. + """ raise NotImplementedError() - def update(self, **kwargs): + def update(self, **kwargs) -> None: + """ + .. method:: Bootmodule.update(**kwargs) + + Update the boot settings. Some boot sequences allow specifying settings + that will be utilized during boot (e.g. linux kernel boot command line). The + default implementation will set each setting in ``kwargs`` as an attribute of + the boot module (or update the existing attribute). + """ for name, value in kwargs.items(): if not hasattr(self, name): raise ValueError('Unknown parameter "{}" for {}'.format(name, self.name)) @@ -128,15 +253,61 @@ def update(self, **kwargs): class FlashModule(Module): - - kind = 'flash' - - def __call__(self, image_bundle=None, images=None, boot_config=None, connect=True): + """ + A Devlib module used for performing firmware or image flashing operations on a target device. + + This module provides an abstraction for managing device flashing, such as flashing new + bootloaders, system images, or recovery partitions, depending on the target platform. + + The `kind` attribute identifies the type of this module and is used by Devlib's internal + module management system to categorize and invoke the appropriate functionality. + + Attributes: + kind (str): The unique identifier for this module type. For `FlashModule`, this is "flash". + + Typical Usage: + This module is automatically loaded onto targets that support flashing operations, + such as development boards or phones with bootloader access. 
+ + Example: + >>> if FlashModule.probe(target): + >>> flash = FlashModule(target) + >>> flash.install() + >>> flash.flash_image("/path/to/image.img", partition="boot") + + Note: + Subclasses of FlashModule should implement the actual flashing logic, as this base + class only provides the interface and identification mechanism. + """ + kind: str = 'flash' + + def __call__(self, image_bundle: Optional[str] = None, + images: Optional[Dict[str, str]] = None, + boot_config: Any = None, connect: bool = True) -> None: + """ + .. method:: __call__(image_bundle=None, images=None, boot_config=None, connect=True) + + Must be implemented by derived classes. + + Flash the target platform with the specified images. + + :param image_bundle: A compressed bundle of image files with any associated + metadata. The format of the bundle is specific to a + particular implementation. + :param images: A dict mapping image names/identifiers to the path on the + host file system of the corresponding image file. If both + this and ``image_bundle`` are specified, individual images + will override those in the bundle. + :param boot_config: Some platforms require specifying boot arguments at the + time of flashing the images, rather than during each + reboot. For other platforms, this will be ignored. + :connect: Specifiy whether to try and connect to the target after flashing. + """ raise NotImplementedError() -def get_module(mod): - def from_registry(mod): +def get_module(mod: Union[str, Type[Module]]) -> Type[Module]: + def from_registry(mod: str): try: return _module_registry[mod] except KeyError: diff --git a/devlib/module/android.py b/devlib/module/android.py index 70564fd05..c9c75ec11 100644 --- a/devlib/module/android.py +++ b/devlib/module/android.py @@ -1,4 +1,4 @@ -# Copyright 2014-2018 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -23,15 +23,18 @@ from devlib.exception import HostError from devlib.utils.android import fastboot_flash_partition, fastboot_command from devlib.utils.misc import merge_dicts, safe_extract +from typing import (TYPE_CHECKING, Any, Optional, Dict, List, cast) +if TYPE_CHECKING: + from devlib.target import Target, AndroidTarget class FastbootFlashModule(FlashModule): - name = 'fastboot' - description = """ + name: str = 'fastboot' + description: str = """ Enables automated flashing of images using the fastboot utility. - To use this flasher, a set of image files to be flused are required. + To use this flasher, a set of image files to be flashed are required. In addition a mapping between partitions and image file is required. There are two ways to specify those requirements: @@ -47,59 +50,68 @@ class FastbootFlashModule(FlashModule): """ - delay = 0.5 - partitions_file_name = 'partitions.txt' + delay: float = 0.5 + partitions_file_name: str = 'partitions.txt' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: return target.os == 'android' - def __call__(self, image_bundle=None, images=None, bootargs=None, connect=True): + def __call__(self, image_bundle: Optional[str] = None, + images: Optional[Dict[str, str]] = None, + bootargs: Any = None, connect: bool = True) -> None: if bootargs: raise ValueError('{} does not support boot configuration'.format(self.name)) - self.prelude_done = False - to_flash = {} + self.prelude_done: bool = False + to_flash: Dict[str, str] = {} if image_bundle: # pylint: disable=access-member-before-definition image_bundle = expand_path(image_bundle) to_flash = self._bundle_to_images(image_bundle) to_flash = merge_dicts(to_flash, images or {}, should_normalize=False) for partition, image_path in to_flash.items(): self.logger.debug('flashing {}'.format(partition)) - self._flash_image(self.target, partition, expand_path(image_path)) + self._flash_image(cast('AndroidTarget', self.target), partition, 
expand_path(image_path)) fastboot_command('reboot') if connect: self.target.connect(timeout=180) - def _validate_image_bundle(self, image_bundle): + def _validate_image_bundle(self, image_bundle: str) -> None: + """ + make sure the image bundle is a tarfile and it can be opened and it contains the + required partition file + """ if not tarfile.is_tarfile(image_bundle): raise HostError('File {} is not a tarfile'.format(image_bundle)) with tarfile.open(image_bundle) as tar: - files = [tf.name for tf in tar.getmembers()] + files: List[str] = [tf.name for tf in tar.getmembers()] if not any(pf in files for pf in (self.partitions_file_name, '{}/{}'.format(files[0], self.partitions_file_name))): HostError('Image bundle does not contain the required partition file (see documentation)') - def _bundle_to_images(self, image_bundle): + def _bundle_to_images(self, image_bundle: str) -> Dict[str, str]: """ Extracts the bundle to a temporary location and creates a mapping between the contents of the bundle - and images to be flushed. + and images to be flashed. 
""" self._validate_image_bundle(image_bundle) - extract_dir = tempfile.mkdtemp() + extract_dir: str = tempfile.mkdtemp() with tarfile.open(image_bundle) as tar: safe_extract(tar, path=extract_dir) - files = [tf.name for tf in tar.getmembers()] + files: List[str] = [tf.name for tf in tar.getmembers()] if self.partitions_file_name not in files: extract_dir = os.path.join(extract_dir, files[0]) - partition_file = os.path.join(extract_dir, self.partitions_file_name) + partition_file: str = os.path.join(extract_dir, self.partitions_file_name) return get_mapping(extract_dir, partition_file) - def _flash_image(self, target, partition, image_path): + def _flash_image(self, target: 'AndroidTarget', partition: str, image_path: str) -> None: + """ + flash the image into the partition using fastboot + """ if not self.prelude_done: self._fastboot_prelude(target) fastboot_flash_partition(partition, image_path) time.sleep(self.delay) - def _fastboot_prelude(self, target): + def _fastboot_prelude(self, target: 'AndroidTarget') -> None: target.reset(fastboot=True) time.sleep(self.delay) self.prelude_done = True @@ -107,15 +119,21 @@ def _fastboot_prelude(self, target): # utility functions -def expand_path(original_path): +def expand_path(original_path: str) -> str: + """ + expand ~ and ~user in the path + """ path = os.path.abspath(os.path.expanduser(original_path)) if not os.path.exists(path): raise HostError('{} does not exist.'.format(path)) return path -def get_mapping(base_dir, partition_file): - mapping = {} +def get_mapping(base_dir: str, partition_file: str) -> Dict[str, str]: + """ + get the image and partition mapping info from partition txt file + """ + mapping: Dict[str, str] = {} with open(partition_file) as pf: for line in pf: pair = line.split() diff --git a/devlib/module/biglittle.py b/devlib/module/biglittle.py index 7124f65a5..f3e60258c 100644 --- a/devlib/module/biglittle.py +++ b/devlib/module/biglittle.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# 
Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,12 @@ # from devlib.module import Module +from devlib.module.hotplug import HotplugModule +from devlib.module.cpufreq import CpufreqModule +from typing import (TYPE_CHECKING, cast, List, + Optional, Dict) +if TYPE_CHECKING: + from devlib.target import Target class BigLittleModule(Module): @@ -21,189 +27,307 @@ class BigLittleModule(Module): name = 'bl' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: return target.big_core is not None @property - def bigs(self): + def bigs(self) -> List[int]: + """ + get the list of big cores + """ return [i for i, c in enumerate(self.target.platform.core_names) if c == self.target.platform.big_core] @property - def littles(self): + def littles(self) -> List[int]: + """ + get the list of little cores + """ return [i for i, c in enumerate(self.target.platform.core_names) if c == self.target.platform.little_core] @property - def bigs_online(self): + def bigs_online(self) -> List[int]: + """ + get the list of big cores which are online + """ return list(sorted(set(self.bigs).intersection(self.target.list_online_cpus()))) @property - def littles_online(self): + def littles_online(self) -> List[int]: + """ + get the list of little cores which are online + """ return list(sorted(set(self.littles).intersection(self.target.list_online_cpus()))) # hotplug - def online_all_bigs(self): - self.target.hotplug.online(*self.bigs) - - def offline_all_bigs(self): - self.target.hotplug.offline(*self.bigs) - - def online_all_littles(self): - self.target.hotplug.online(*self.littles) - - def offline_all_littles(self): - self.target.hotplug.offline(*self.littles) + def online_all_bigs(self) -> None: + """ + make all big cores go online + """ + cast(HotplugModule, self.target.hotplug).online(*self.bigs) + + def offline_all_bigs(self) -> None: + """ + 
make all big cores go offline + """ + cast(HotplugModule, self.target.hotplug).offline(*self.bigs) + + def online_all_littles(self) -> None: + """ + make all little cores go online + """ + cast(HotplugModule, self.target.hotplug).online(*self.littles) + + def offline_all_littles(self) -> None: + """ + make all little cores go offline + """ + cast(HotplugModule, self.target.hotplug).offline(*self.littles) # cpufreq - def list_bigs_frequencies(self): + def list_bigs_frequencies(self) -> Optional[List[int]]: + """ + get the big cores frequencies + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.list_frequencies(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).list_frequencies(bigs_online[0]) + return None - def list_bigs_governors(self): + def list_bigs_governors(self) -> Optional[List[str]]: + """ + get the governors supported for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.list_governors(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).list_governors(bigs_online[0]) + return None - def list_bigs_governor_tunables(self): + def list_bigs_governor_tunables(self) -> Optional[List[str]]: + """ + get the tunable governors supported for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.list_governor_tunables(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).list_governor_tunables(bigs_online[0]) + return None - def list_littles_frequencies(self): + def list_littles_frequencies(self) -> Optional[List[int]]: + """ + get the little cores frequencies + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.list_frequencies(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).list_frequencies(littles_online[0]) + return None - def list_littles_governors(self): + def list_littles_governors(self) -> 
Optional[List[str]]: + """ + get the governors supported for the first little core that is online + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.list_governors(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).list_governors(littles_online[0]) + return None - def list_littles_governor_tunables(self): + def list_littles_governor_tunables(self) -> Optional[List[str]]: + """ + get the tunable governors supported for the first little core that is online + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.list_governor_tunables(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).list_governor_tunables(littles_online[0]) + return None - def get_bigs_governor(self): + def get_bigs_governor(self) -> Optional[str]: + """ + get the current governor set for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.get_governor(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_governor(bigs_online[0]) + return None - def get_bigs_governor_tunables(self): + def get_bigs_governor_tunables(self) -> Optional[Dict[str, str]]: + """ + get the current governor tunables set for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.get_governor_tunables(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_governor_tunables(bigs_online[0]) + return None - def get_bigs_frequency(self): + def get_bigs_frequency(self) -> Optional[int]: + """ + get the current frequency that is set for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.get_frequency(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_frequency(bigs_online[0]) + return None - def get_bigs_min_frequency(self): + def get_bigs_min_frequency(self) -> Optional[int]: + """ + get 
the current minimum frequency that is set for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.get_min_frequency(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_min_frequency(bigs_online[0]) + return None - def get_bigs_max_frequency(self): + def get_bigs_max_frequency(self) -> Optional[int]: + """ + get the current maximum frequency that is set for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.get_max_frequency(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_max_frequency(bigs_online[0]) + return None - def get_littles_governor(self): + def get_littles_governor(self) -> Optional[str]: + """ + get the current governor set for the first little core that is online + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.get_governor(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_governor(littles_online[0]) + return None - def get_littles_governor_tunables(self): + def get_littles_governor_tunables(self) -> Optional[Dict[str, str]]: + """ + get the current governor tunables set for the first little core that is online + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.get_governor_tunables(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_governor_tunables(littles_online[0]) + return None - def get_littles_frequency(self): + def get_littles_frequency(self) -> Optional[int]: + """ + get the current frequency that is set for the first little core that is online + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.get_frequency(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_frequency(littles_online[0]) + return None - def get_littles_min_frequency(self): + def get_littles_min_frequency(self) -> Optional[int]: 
+ """ + get the current minimum frequency that is set for the first little core that is online + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.get_min_frequency(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_min_frequency(littles_online[0]) + return None - def get_littles_max_frequency(self): + def get_littles_max_frequency(self) -> Optional[int]: + """ + get the current maximum frequency that is set for the first little core that is online + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.get_max_frequency(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_max_frequency(littles_online[0]) + return None - def set_bigs_governor(self, governor, **kwargs): + def set_bigs_governor(self, governor: str, **kwargs) -> None: + """ + set governor for the first online big core + """ bigs_online = self.bigs_online if bigs_online: - self.target.cpufreq.set_governor(bigs_online[0], governor, **kwargs) + cast(CpufreqModule, self.target.cpufreq).set_governor(bigs_online[0], governor, **kwargs) else: raise ValueError("All bigs appear to be offline") - def set_bigs_governor_tunables(self, governor, **kwargs): + def set_bigs_governor_tunables(self, governor: str, **kwargs) -> None: + """ + set governor tunables for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - self.target.cpufreq.set_governor_tunables(bigs_online[0], governor, **kwargs) + cast(CpufreqModule, self.target.cpufreq).set_governor_tunables(bigs_online[0], governor, **kwargs) else: raise ValueError("All bigs appear to be offline") - def set_bigs_frequency(self, frequency, exact=True): + def set_bigs_frequency(self, frequency: int, exact: bool = True) -> None: + """ + set the frequency for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - self.target.cpufreq.set_frequency(bigs_online[0], frequency, exact) + 
cast(CpufreqModule, self.target.cpufreq).set_frequency(bigs_online[0], frequency, exact) else: raise ValueError("All bigs appear to be offline") - def set_bigs_min_frequency(self, frequency, exact=True): + def set_bigs_min_frequency(self, frequency: int, exact: bool = True) -> None: + """ + set the minimum value for the cpu frequency for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - self.target.cpufreq.set_min_frequency(bigs_online[0], frequency, exact) + cast(CpufreqModule, self.target.cpufreq).set_min_frequency(bigs_online[0], frequency, exact) else: raise ValueError("All bigs appear to be offline") - def set_bigs_max_frequency(self, frequency, exact=True): + def set_bigs_max_frequency(self, frequency: int, exact: bool = True) -> None: + """ + set the maximum value for the cpu frequency for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - self.target.cpufreq.set_max_frequency(bigs_online[0], frequency, exact) + cast(CpufreqModule, self.target.cpufreq).set_max_frequency(bigs_online[0], frequency, exact) else: raise ValueError("All bigs appear to be offline") - def set_littles_governor(self, governor, **kwargs): + def set_littles_governor(self, governor: str, **kwargs) -> None: + """ + set governor for the first online little core + """ littles_online = self.littles_online if littles_online: - self.target.cpufreq.set_governor(littles_online[0], governor, **kwargs) + cast(CpufreqModule, self.target.cpufreq).set_governor(littles_online[0], governor, **kwargs) else: raise ValueError("All littles appear to be offline") - def set_littles_governor_tunables(self, governor, **kwargs): + def set_littles_governor_tunables(self, governor: str, **kwargs) -> None: + """ + set governor tunables for the first little core that is online + """ littles_online = self.littles_online if littles_online: - self.target.cpufreq.set_governor_tunables(littles_online[0], governor, **kwargs) + 
cast(CpufreqModule, self.target.cpufreq).set_governor_tunables(littles_online[0], governor, **kwargs) else: raise ValueError("All littles appear to be offline") - def set_littles_frequency(self, frequency, exact=True): + def set_littles_frequency(self, frequency: int, exact: bool = True) -> None: + """ + set the frequency for the first little core that is online + """ littles_online = self.littles_online if littles_online: - self.target.cpufreq.set_frequency(littles_online[0], frequency, exact) + cast(CpufreqModule, self.target.cpufreq).set_frequency(littles_online[0], frequency, exact) else: raise ValueError("All littles appear to be offline") - def set_littles_min_frequency(self, frequency, exact=True): + def set_littles_min_frequency(self, frequency: int, exact: bool = True) -> None: + """ + set the minimum value for the cpu frequency for the first little core that is online + """ littles_online = self.littles_online if littles_online: - self.target.cpufreq.set_min_frequency(littles_online[0], frequency, exact) + cast(CpufreqModule, self.target.cpufreq).set_min_frequency(littles_online[0], frequency, exact) else: raise ValueError("All littles appear to be offline") - def set_littles_max_frequency(self, frequency, exact=True): + def set_littles_max_frequency(self, frequency: int, exact: bool = True) -> None: + """ + set the maximum value for the cpu frequency for the first little core that is online + """ littles_online = self.littles_online if littles_online: - self.target.cpufreq.set_max_frequency(littles_online[0], frequency, exact) + cast(CpufreqModule, self.target.cpufreq).set_max_frequency(littles_online[0], frequency, exact) else: raise ValueError("All littles appear to be offline") diff --git a/devlib/module/cgroups.py b/devlib/module/cgroups.py index a7edf879c..3bbda242d 100644 --- a/devlib/module/cgroups.py +++ b/devlib/module/cgroups.py @@ -1,4 +1,4 @@ -# Copyright 2014-2018 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache 
License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,71 +22,76 @@ from devlib.module import Module from devlib.exception import TargetStableError -from devlib.utils.misc import list_to_ranges, isiterable +from devlib.utils.misc import list_to_ranges, isiterable, get_logger from devlib.utils.types import boolean from devlib.utils.asyn import asyncf, run +from typing import (TYPE_CHECKING, Optional, List, Dict, + Union, Tuple, cast, Set) + +if TYPE_CHECKING: + from devlib.target import Target, FstabEntry class Controller(object): - def __init__(self, kind, hid, clist): + def __init__(self, kind: str, hid: int, clist: List[str]): """ Initialize a controller given the hierarchy it belongs to. :param kind: the name of the controller - :type kind: str :param hid: the Hierarchy ID this controller is mounted on - :type hid: int :param clist: the list of controller mounted in the same hierarchy - :type clist: list(str) """ - self.mount_name = 'devlib_cgh{}'.format(hid) - self.kind = kind - self.hid = hid - self.clist = clist - self.target = None - self._noprefix = False - - self.logger = logging.getLogger('CGroup.'+self.kind) + self.mount_name: str = 'devlib_cgh{}'.format(hid) + self.kind: str = kind + self.hid: int = hid + self.clist: List[str] = clist + self.target: Optional['Target'] = None + self._noprefix: bool = False + + self.logger: logging.Logger = get_logger('CGroup.' 
+ self.kind) self.logger.debug('Initialized [%s, %d, %s]', self.kind, self.hid, self.clist) - self.mount_point = None - self._cgroups = {} + self.mount_point: Optional[str] = None + self._cgroups: Dict[str, 'CGroup'] = {} @asyncf - async def mount(self, target, mount_root): - - mounted = target.list_file_systems() + async def mount(self, target: 'Target', mount_root: str) -> None: + """ + mount the controller in mount point + """ + mounted: List[FstabEntry] = target.list_file_systems() if self.mount_name in [e.device for e in mounted]: # Identify mount point if controller is already in use self.mount_point = [ - fs.mount_point - for fs in mounted - if fs.device == self.mount_name - ][0] + fs.mount_point + for fs in mounted + if fs.device == self.mount_name + ][0] else: # Mount the controller if not already in use - self.mount_point = target.path.join(mount_root, self.mount_name) - await target.execute.asyn('mkdir -p {} 2>/dev/null'\ - .format(self.mount_point), as_root=True) - await target.execute.asyn('mount -t cgroup -o {} {} {}'\ - .format(','.join(self.clist), - self.mount_name, - self.mount_point), - as_root=True) + if target.path: + self.mount_point = target.path.join(mount_root, self.mount_name) + await target.execute.asyn('mkdir -p {} 2>/dev/null' + .format(self.mount_point), as_root=True) + await target.execute.asyn('mount -t cgroup -o {} {} {}' + .format(','.join(self.clist), + self.mount_name, + self.mount_point), + as_root=True) # Check if this controller uses "noprefix" option - output = await target.execute.asyn('mount | grep "{} "'.format(self.mount_name)) + output: str = await target.execute.asyn('mount | grep "{} "'.format(self.mount_name)) if 'noprefix' in output: self._noprefix = True # self.logger.debug('Controller %s using "noprefix" option', # self.kind) self.logger.debug('Controller %s mounted under: %s (noprefix=%s)', - self.kind, self.mount_point, self._noprefix) + self.kind, self.mount_point, self._noprefix) # Mark this contoller as 
available self.target = target @@ -94,39 +99,51 @@ async def mount(self, target, mount_root): # Create root control group self.cgroup('/') - def cgroup(self, name): + def cgroup(self, name: str) -> 'CGroup': + """ + get the control group with the name + """ if not self.target: - raise RuntimeError('CGroup creation failed: {} controller not mounted'\ - .format(self.kind)) + raise RuntimeError('CGroup creation failed: {} controller not mounted' + .format(self.kind)) if name not in self._cgroups: self._cgroups[name] = CGroup(self, name) return self._cgroups[name] - def exists(self, name): + def exists(self, name: str) -> bool: + """ + returns True if the control group with this name exists + """ if not self.target: - raise RuntimeError('CGroup creation failed: {} controller not mounted'\ - .format(self.kind)) + raise RuntimeError('CGroup creation failed: {} controller not mounted' + .format(self.kind)) if name not in self._cgroups: self._cgroups[name] = CGroup(self, name, create=False) return self._cgroups[name].exists() - def list_all(self): + def list_all(self) -> List[str]: + """ + List all control groups for this controller + """ self.logger.debug('Listing groups for %s controller', self.kind) - output = self.target.execute('{} find {} -type d'\ - .format(self.target.busybox, self.mount_point), - as_root=True) - cgroups = [] - for cg in output.splitlines(): - cg = cg.replace(self.mount_point + '/', '/') - cg = cg.replace(self.mount_point, '/') - cg = cg.strip() - if cg == '': - continue - self.logger.debug('Populate %s cgroup: %s', self.kind, cg) - cgroups.append(cg) + if self.target: + output: str = self.target.execute('{} find {} -type d' + .format(self.target.busybox, self.mount_point), + as_root=True) + cgroups: List[str] = [] + for cg in output.splitlines(): + if self.mount_point: + cg = cg.replace(self.mount_point + '/', '/') + cg = cg.replace(self.mount_point, '/') + cg = cg.strip() + if cg == '': + continue + self.logger.debug('Populate %s cgroup: %s', 
self.kind, cg) + cgroups.append(cg) return cgroups - def move_tasks(self, source, dest, exclude=None): + def move_tasks(self, source: str, dest: str, + exclude: Optional[Union[str, List[str]]] = None) -> None: if isinstance(exclude, str): warnings.warn("Controller.move_tasks() takes needs a _list_ of exclude patterns, not a string", DeprecationWarning) exclude = [exclude] @@ -143,17 +160,17 @@ def move_tasks(self, source, dest, exclude=None): srcg = self.cgroup(source) dstg = self.cgroup(dest) + if self.target and srcg.directory and dstg.directory: + self.target._execute_util( # pylint: disable=protected-access + 'cgroups_tasks_move {src} {dst} {exclude}'.format( + src=quote(srcg.directory), + dst=quote(dstg.directory), + exclude=exclude, + ), + as_root=True, + ) - self.target._execute_util( # pylint: disable=protected-access - 'cgroups_tasks_move {src} {dst} {exclude}'.format( - src=quote(srcg.directory), - dst=quote(dstg.directory), - exclude=exclude, - ), - as_root=True, - ) - - def move_all_tasks_to(self, dest, exclude=None): + def move_all_tasks_to(self, dest: str, exclude: Optional[Union[str, List[str]]] = None) -> None: """ Move all the tasks to the specified CGroup @@ -166,7 +183,6 @@ def move_all_tasks_to(self, dest, exclude=None): tasks. :param exclude: list of commands to keep in the root CGroup - :type exclude: list(str) """ if exclude is None: exclude = [] @@ -187,10 +203,10 @@ def move_all_tasks_to(self, dest, exclude=None): self.move_tasks(cgroup, dest, exclude) # pylint: disable=too-many-locals - def tasks(self, cgroup, - filter_tid='', - filter_tname='', - filter_tcmdline=''): + def tasks(self, cgroup: str, + filter_tid: str = '', + filter_tname: str = '', + filter_tcmdline: str = '') -> Dict[int, Tuple[str, str]]: """ Report the tasks that are included in a cgroup. 
The tasks can be filtered by their tid, tname or tcmdline if filter_tid, filter_tname or @@ -202,13 +218,10 @@ def tasks(self, cgroup, 903,cameraserver,/system/bin/cameraserver :params filter_tid: regexp pattern to filter by TID - :type filter_tid: str :params filter_tname: regexp pattern to filter by tname - :type filter_tname: str :params filter_tcmdline: regexp pattern to filter by tcmdline - :type filter_tcmdline: str :returns: a dictionary in the form: {tid:(tname, tcmdline)} """ @@ -222,14 +235,16 @@ def tasks(self, cgroup, cg = self._cgroups[cgroup] except KeyError as e: raise ValueError('Unknown group: {}'.format(e)) - output = self.target._execute_util( # pylint: disable=protected-access - 'cgroups_tasks_in {}'.format(cg.directory), - as_root=True) - entries = output.splitlines() - tasks = {} + if self.target is None: + raise ValueError("Target is None") + output: str = self.target._execute_util( # pylint: disable=protected-access + 'cgroups_tasks_in {}'.format(cg.directory), + as_root=True) + entries: List[str] = output.splitlines() + tasks: Dict[int, Tuple[str, str]] = {} for task in entries: - fields = task.split(',', 2) - nr_fields = len(fields) + fields: List[str] = task.split(',', 2) + nr_fields: int = len(fields) if nr_fields < 2: continue elif nr_fields == 2: @@ -248,65 +263,86 @@ def tasks(self, cgroup, tasks[int(tid_str)] = (tname, tcmdline) return tasks - def tasks_count(self, cgroup): + def tasks_count(self, cgroup: str) -> int: + """ + count of the number of tasks in the cgroup + """ try: cg = self._cgroups[cgroup] except KeyError as e: raise ValueError('Unknown group: {}'.format(e)) + if self.target is None: + raise ValueError("Target is None") output = self.target.execute( - '{} wc -l {}/tasks'.format( - self.target.busybox, cg.directory), - as_root=True) + '{} wc -l {}/tasks'.format( + self.target.busybox, cg.directory), + as_root=True) return int(output.split()[0]) - def tasks_per_group(self): - tasks = {} + def tasks_per_group(self) -> 
Dict[str, int]: + """ + tasks in all cgroups + """ + tasks: Dict[str, int] = {} for cg in self.list_all(): tasks[cg] = self.tasks_count(cg) return tasks + class CGroup(object): - def __init__(self, controller, name, create=True): - self.logger = logging.getLogger('cgroups.' + controller.kind) - self.target = controller.target - self.controller = controller - self.name = name + def __init__(self, controller: 'Controller', name: str, create: bool = True): + self.logger: logging.Logger = get_logger('cgroups.' + controller.kind) + self.target: Optional['Target'] = controller.target + self.controller: Controller = controller + self.name: str = name # Control cgroup path - self.directory = controller.mount_point - + self.directory: Optional[str] = controller.mount_point + if self.target is None: + raise ValueError("Target is None") + if self.target.path is None: + raise ValueError("Target.path is None") if name != '/': self.directory = self.target.path.join(controller.mount_point, name.strip('/')) # Setup path for tasks file - self.tasks_file = self.target.path.join(self.directory, 'tasks') - self.procs_file = self.target.path.join(self.directory, 'cgroup.procs') + self.tasks_file: str = self.target.path.join(self.directory, 'tasks') + self.procs_file: str = self.target.path.join(self.directory, 'cgroup.procs') if not create: return self.logger.debug('Creating cgroup %s', self.directory) - self.target.execute('[ -d {0} ] || mkdir -p {0}'\ - .format(self.directory), as_root=True) + self.target.execute('[ -d {0} ] || mkdir -p {0}' + .format(self.directory), as_root=True) - def exists(self): + def exists(self) -> bool: + """ + return true if the directory of the control group exists + """ try: - self.target.execute('[ -d {0} ]'\ - .format(self.directory), as_root=True) + if self.target is None: + raise TargetStableError + self.target.execute('[ -d {0} ]' + .format(self.directory), as_root=True) return True except TargetStableError: return False - def get(self): - conf = {} 
- + def get(self) -> Dict[str, str]: + """ + get attributes and associated value from control groups + """ + conf: Dict[str, str] = {} + if self.target is None: + raise ValueError("Target is None") self.logger.debug('Reading %s attributes from:', self.controller.kind) self.logger.debug(' %s', self.directory) - output = self.target._execute_util( # pylint: disable=protected-access - 'cgroups_get_attributes {} {}'.format( - self.directory, self.controller.kind), - as_root=True) + output: str = self.target._execute_util( # pylint: disable=protected-access + 'cgroups_get_attributes {} {}'.format( + self.directory, self.controller.kind), + as_root=True) for res in output.splitlines(): attr = res.split(':')[0] value = res.split(':')[1] @@ -314,78 +350,98 @@ def get(self): return conf - def set(self, **attrs): + def set(self, **attrs: Union[str, List[int], int]) -> None: + """ + set attributes to the control group + """ for idx in attrs: if isiterable(attrs[idx]): - attrs[idx] = list_to_ranges(attrs[idx]) + attrs[idx] = list_to_ranges(cast(List, attrs[idx])) # Build attribute path if self.controller._noprefix: # pylint: disable=protected-access attr_name = '{}'.format(idx) else: attr_name = '{}.{}'.format(self.controller.kind, idx) - path = self.target.path.join(self.directory, attr_name) + path: str = self.target.path.join(self.directory, attr_name) if self.target and self.target.path else '' self.logger.debug('Set attribute [%s] to: %s"', - path, attrs[idx]) + path, attrs[idx]) # Set the attribute value try: - self.target.write_value(path, attrs[idx]) + if self.target: + self.target.write_value(path, attrs[idx]) except TargetStableError: # Check if the error is due to a non-existing attribute - attrs = self.get() - if idx not in attrs: - raise ValueError('Controller [{}] does not provide attribute [{}]'\ + attrs_int = self.get() + if idx not in attrs_int: + raise ValueError('Controller [{}] does not provide attribute [{}]' .format(self.controller.kind, attr_name)) raise 
- def get_tasks(self): - task_ids = self.target.read_value(self.tasks_file).split() + def get_tasks(self) -> List[int]: + """ + get the ids of tasks in the control group + """ + task_ids: List[str] = self.target.read_value(self.tasks_file).split() if self.target else [] self.logger.debug('Tasks: %s', task_ids) return list(map(int, task_ids)) - def add_task(self, tid): - self.target.write_value(self.tasks_file, tid, verify=False) + def add_task(self, tid: int) -> None: + """ + add task to the control group + """ + if self.target: + self.target.write_value(self.tasks_file, tid, verify=False) - def add_tasks(self, tasks): + def add_tasks(self, tasks: List[int]) -> None: + """ + add multiple tasks to the control group + """ for tid in tasks: self.add_task(tid) - def add_proc(self, pid): - self.target.write_value(self.procs_file, pid, verify=False) + def add_proc(self, pid: int) -> None: + """ + add process to the control group + """ + if self.target: + self.target.write_value(self.procs_file, pid, verify=False) + CgroupSubsystemEntry = namedtuple('CgroupSubsystemEntry', 'name hierarchy num_cgroups enabled') + class CgroupsModule(Module): - name = 'cgroups' - stage = 'setup' + name: str = 'cgroups' + stage: str = 'setup' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: if not target.is_rooted: return False if target.file_exists('/proc/cgroups'): return True return target.config.has('cgroups') - def __init__(self, target): + def __init__(self, target: 'Target'): super(CgroupsModule, self).__init__(target) - self.logger = logging.getLogger('CGroups') + self.logger: logging.Logger = get_logger('CGroups') # Set Devlib's CGroups mount point - self.cgroup_root = target.path.join( - target.working_directory, 'cgroups') + self.cgroup_root: str = target.path.join( + target.working_directory, 'cgroups') if target.path else '' # Get the list of the available controllers - subsys = self.list_subsystems() + subsys: List['CgroupSubsystemEntry'] = 
self.list_subsystems() if not subsys: self.logger.warning('No CGroups controller available') return # Map hierarchy IDs into a list of controllers - hierarchy = {} + hierarchy: Dict[int, List[str]] = {} for ss in subsys: try: hierarchy[ss.hierarchy].append(ss.name) @@ -395,10 +451,13 @@ def __init__(self, target): # Initialize controllers self.logger.info('Available controllers:') - self.controllers = {} + self.controllers: Dict[str, Controller] = {} - async def register_controller(ss): - hid = ss.hierarchy + async def register_controller(ss: 'CgroupSubsystemEntry') -> None: + """ + register controller to control group module + """ + hid: int = ss.hierarchy controller = Controller(ss.name, hid, hierarchy[hid]) try: await controller.mount.asyn(self.target, self.cgroup_root) @@ -416,11 +475,13 @@ async def register_controller(ss): ) ) - - def list_subsystems(self): - subsystems = [] - for line in self.target.execute('{} cat /proc/cgroups'\ - .format(self.target.busybox), as_root=self.target.is_rooted).splitlines()[1:]: + def list_subsystems(self) -> List['CgroupSubsystemEntry']: + """ + get the list of subsystems as a list of class:CgroupSubsystemEntry objects + """ + subsystems: List['CgroupSubsystemEntry'] = [] + for line in self.target.execute('{} cat /proc/cgroups' + .format(self.target.busybox), as_root=self.target.is_rooted).splitlines()[1:]: line = line.strip() if not line or line.startswith('#') or line.endswith('0'): continue @@ -431,14 +492,16 @@ def list_subsystems(self): boolean(enabled))) return subsystems - - def controller(self, kind): + def controller(self, kind: str) -> Optional[Controller]: + """ + get the controller of the specified kind + """ if kind not in self.controllers: self.logger.warning('Controller %s not available', kind) return None return self.controllers[kind] - def run_into_cmd(self, cgroup, cmdline): + def run_into_cmd(self, cgroup: str, cmdline: str) -> str: """ Get the command to run a command into a given cgroup @@ -450,10 +513,10 
@@ def run_into_cmd(self, cgroup, cmdline): message = 'cgroup name "{}" must start with "/"'.format(cgroup) raise ValueError(message) return 'CGMOUNT={} {} cgroups_run_into {} {}'\ - .format(self.cgroup_root, self.target.shutils, - cgroup, cmdline) + .format(self.cgroup_root, self.target.shutils, + cgroup, cmdline) - def run_into(self, cgroup, cmdline, as_root=None): + def run_into(self, cgroup: str, cmdline: str, as_root: Optional[bool] = None) -> str: """ Run the specified command into the specified CGroup @@ -465,13 +528,13 @@ def run_into(self, cgroup, cmdline, as_root=None): """ if as_root is None: as_root = self.target.is_rooted - cmd = self.run_into_cmd(cgroup, cmdline) - raw_output = self.target.execute(cmd, as_root=as_root) + cmd: str = self.run_into_cmd(cgroup, cmdline) + raw_output: str = self.target.execute(cmd, as_root=as_root) # First line of output comes from shutils; strip it out. return raw_output.split('\n', 1)[1] - def cgroups_tasks_move(self, srcg, dstg, exclude=''): + def cgroups_tasks_move(self, srcg: str, dstg: str, exclude: str = '') -> str: """ Move all the tasks from the srcg CGroup to the dstg one. A regexps of tasks names can be used to defined tasks which should not @@ -481,7 +544,7 @@ def cgroups_tasks_move(self, srcg, dstg, exclude=''): 'cgroups_tasks_move {} {} {}'.format(srcg, dstg, exclude), as_root=True) - def isolate(self, cpus, exclude=None): + def isolate(self, cpus: List[int], exclude: Optional[List[str]] = None) -> Tuple[CGroup, CGroup]: """ Remove all userspace tasks from specified CPUs. @@ -492,7 +555,6 @@ def isolate(self, cpus, exclude=None): tasks running unless explicitely moved into the isolated group. 
:param cpus: the list of CPUs to isolate - :type cpus: list(int) :return: the (sandbox, isolated) tuple, where: sandbox is the CGroup of sandboxed CPUs @@ -500,12 +562,14 @@ def isolate(self, cpus, exclude=None): """ if exclude is None: exclude = [] - all_cpus = set(range(self.target.number_of_cpus)) - sbox_cpus = list(all_cpus - set(cpus)) - isol_cpus = list(all_cpus - set(sbox_cpus)) + all_cpus: Set[int] = set(range(self.target.number_of_cpus)) + sbox_cpus: List[int] = list(all_cpus - set(cpus)) + isol_cpus: List[int] = list(all_cpus - set(sbox_cpus)) # Create Sandbox and Isolated cpuset CGroups - cpuset = self.controller('cpuset') + cpuset: Optional[Controller] = self.controller('cpuset') + if cpuset is None: + raise ValueError("cpuset is None") sbox_cg = cpuset.cgroup('/DEVLIB_SBOX') isol_cg = cpuset.cgroup('/DEVLIB_ISOL') @@ -518,7 +582,8 @@ def isolate(self, cpus, exclude=None): return sbox_cg, isol_cg - def freeze(self, exclude=None, thaw=False): + def freeze(self, exclude: Optional[List[str]] = None, + thaw: bool = False) -> Optional[Dict[int, Tuple[str, str]]]: """ Freeze all user-space tasks but the specified ones @@ -530,10 +595,8 @@ def freeze(self, exclude=None, thaw=False): the PID of these tasks. 
:param exclude: list of commands paths to exclude from freezer - :type exclude: list(str) :param thaw: if true thaw tasks instead - :type thaw: bool """ if exclude is None: @@ -549,10 +612,11 @@ def freeze(self, exclude=None, thaw=False): if thaw: # Restart frozen tasks # pylint: disable=protected-access - freezer.target._execute_util(cmd.format('THAWED'), as_root=True) - # Remove all tasks from freezer - freezer.move_all_tasks_to('/') - return + if freezer.target: + freezer.target._execute_util(cmd.format('THAWED'), as_root=True) + # Remove all tasks from freezer + freezer.move_all_tasks_to('/') + return None # Move all tasks into the freezer group freezer.move_all_tasks_to('/DEVLIB_FREEZER', exclude) @@ -562,6 +626,7 @@ def freeze(self, exclude=None, thaw=False): # Freeze all tasks # pylint: disable=protected-access - freezer.target._execute_util(cmd.format('FROZEN'), as_root=True) + if freezer.target: + freezer.target._execute_util(cmd.format('FROZEN'), as_root=True) return tasks diff --git a/devlib/module/cgroups2.py b/devlib/module/cgroups2.py index a632bbee6..b771cbd72 100644 --- a/devlib/module/cgroups2.py +++ b/devlib/module/cgroups2.py @@ -1,4 +1,4 @@ -# Copyright 2022 ARM Limited +# Copyright 2022-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -101,7 +101,11 @@ from abc import ABC, abstractmethod from contextlib import ExitStack, contextmanager from shlex import quote -from typing import Dict, Set, List, Union, Any +from typing import (Dict, Set, List, Union, Any, + Tuple, cast, Callable, Optional, + Pattern, Match) +from collections.abc import Generator +from contextlib import _GeneratorContextManager from uuid import uuid4 from devlib import LinuxTarget @@ -112,16 +116,17 @@ from devlib.target import FstabEntry from devlib.utils.misc import memoized +# dictionary type frequently being used in this module +ControllerDict = Dict[str, Dict[str, Union[str, int]]] + def _is_systemd_online(target: LinuxTarget): """ Determines if systemd is activated on the target system. :param target: Interface to the target device. - :type target: Target :return: Returns ``True`` if systemd is active, ``False`` otherwise. - :rtype: bool """ try: @@ -132,19 +137,16 @@ def _is_systemd_online(target: LinuxTarget): return True -def _read_lines(target: LinuxTarget, path: str): +def _read_lines(target: LinuxTarget, path: str) -> List[str]: """ Reads the lines of a file stored on the target device. :param target: Interface to target device. - :type target: Target :param path: The path to the file to be read. - :type path: str :return: A list of the words/sentences that result from splitting the read file (trailing and leading white-spaces removed) delimiting on the new-line character. - :rtype: List[str] """ return target.read_value(path=path).split("\n") @@ -157,19 +159,20 @@ def _add_controller_versions(controllers: Dict[str, Dict[str, int]]): :param controllers: A dictionary mapping ``str`` controller names to dictionaries, where the later dictionary contains ``hierarchy`` and ``num_cgroup`` keys mapped to their respective suitable ``int`` values. 
- :type controllers: Dict[str, Dict[str, int]] :return: A dictionary mapping ``str`` controller names to dictionaries, where the later dictionary contains an appended ``version`` key which maps to an ``int`` value representing the version of the respective controller if applicable. - :rtype: Dict[str, Dict[str,int]] """ # Read how the controller versions can be determined here: # https://man7.org/linux/man-pages/man7/cgroups.7.html # (Under NOTES) [Dated 12/08/2022] - def infer_version(config): + def infer_version(config: Dict[str, int]) -> Optional[int]: + """ + determine the controller version + """ if config["hierarchy"] != 0: return 1 elif config["hierarchy"] == 0 and config["num_cgroups"] > 1: @@ -188,31 +191,31 @@ def infer_version(config): def _add_controller_mounts( controllers: Dict[str, Dict[str, int]], target_fs_list: List[FstabEntry] -): +) -> Dict[str, Dict[str, Union[str, int]]]: """ Find the CGroup controller's mount point and adds it as ``mount_point`` key. :param controllers: A dictionary mapping ``str`` controller names to dictionaries, where the later dictionary contains `hierarchy``, ``num_cgroup`` and if appropriate ``version`` keys mapped to their respective suitable ``int`` values. - :type controllers: Dict[str, Dict[str, int]] :param target_fs_list: A list of entries of the NamedTuple type ``FstabEntry``, where each represents a mounted filesystem on the target device. - :type target_fs: List[FstabEntry] :return: A dictionary mapping ``str`` controller names to dictionaries, where the later dictionary contains an appended ``mount_point`` key which maps to the suitable ``str`` value of the respective controllers if applicable. - :rtype: Dict[str, Dict[str, Union[str,int]]] """ # Filter the mounted filesystems on the target device, obtaining the respective V1/V2 FstabEntries. 
- v1_mounts = [fs for fs in target_fs_list if fs.fs_type == "cgroup"] - v2_mounts = [fs for fs in target_fs_list if fs.fs_type == "cgroup2"] + v1_mounts: List[FstabEntry] = [fs for fs in target_fs_list if fs.fs_type == "cgroup"] + v2_mounts: List[FstabEntry] = [fs for fs in target_fs_list if fs.fs_type == "cgroup2"] - def _infer_mount(controller: str, configuration: Dict): - controller_version = configuration.get("version") + def _infer_mount(controller: str, configuration: Dict[str, int]) -> Optional[str]: + """ + determine the controller mount point + """ + controller_version: Optional[int] = configuration.get("version") if controller_version == 1: for mount in v1_mounts: if controller in mount.options.strip().split(","): @@ -225,7 +228,7 @@ def _infer_mount(controller: str, configuration: Dict): return None - return { + return cast(Dict[str, Dict[str, Union[str, int]]], { controller: {**config, "mount_point": path if path is not None else config} for (controller, config, path) in ( ( @@ -235,20 +238,18 @@ def _infer_mount(controller: str, configuration: Dict): ) for (controller, config) in controllers.items() ) - } + }) -def _get_cgroup_controllers(target: LinuxTarget): +def _get_cgroup_controllers(target: LinuxTarget) -> ControllerDict: """ Returns the CGroup controllers that are currently enabled on the target device, alongside their appropriate configurations. :param target: Interface to target device. - :type target: Target :return: A dictionary of controller name keys to dictionary value mappings, where the secondary dictionary contains a mapping between various CGroup controller configuration keys and their respectively obtained values for the respective CGroup controllers. - :rtype: Dict[str, Dict[str,Union[str,int]]] """ # A snippet of the /proc/cgroup is shown below. 
The column entries are separated @@ -257,19 +258,23 @@ def _get_cgroup_controllers(target: LinuxTarget): # #subsys_name hierarchy num_cgroups enabled # cpuset 3 1 1 - PROC_MOUNT_REGEX = re.compile( + PROC_MOUNT_REGEX: Pattern[str] = re.compile( r"^(?!#)(?P.+)\t(?P.+)\t(?P.+)\t(?P.+)" ) - proc_cgroup_file = _read_lines(target=target, path="/proc/cgroups") + proc_cgroup_file: List[str] = _read_lines(target=target, path="/proc/cgroups") - def _parse_controllers(controller): + def _parse_controllers(controller: str) -> Union[Tuple[str, Dict[str, int]], + Tuple[None, None]]: + """ + parse the controllers information from cgroups file + """ match = PROC_MOUNT_REGEX.match(controller.strip()) if match: - name = match.group("name") - enabled = int(match.group("enabled")) - hierarchy = int(match.group("hierarchy")) - num_cgroups = int(match.group("num_cgroups")) + name: str = match.group("name") + enabled: int = int(match.group("enabled")) + hierarchy: int = int(match.group("hierarchy")) + num_cgroups: int = int(match.group("num_cgroups")) # We should ignore disabled controllers. if enabled != 0: config = { @@ -279,9 +284,9 @@ def _parse_controllers(controller): return (name, config) return (None, None) - controllers = dict(map(_parse_controllers, proc_cgroup_file)) - controllers.pop(None) - controllers = _add_controller_versions(controllers=controllers) + controllers_temp = dict(map(_parse_controllers, proc_cgroup_file)) + controllers_temp.pop(None) + controllers = _add_controller_versions(controllers=cast(Dict[str, Dict[str, int]], controllers_temp)) controllers = _add_controller_mounts( controllers=controllers, target_fs_list=target.list_file_systems(), @@ -291,12 +296,11 @@ def _parse_controllers(controller): @contextmanager -def _request_delegation(target: LinuxTarget): +def _request_delegation(target: LinuxTarget) -> Generator[int, None, None]: """ Requests systemd to delegate a subtree CGroup hierarchy to our transient service unit. 
:yield: The Main PID of the delegated transient service unit. - :rtype: int """ service_name = "devlib-" + str(uuid.uuid4().hex) @@ -304,7 +308,7 @@ def _request_delegation(target: LinuxTarget): try: target.execute( 'systemd-run --no-block --property Delegate="yes" --unit {name} --quiet {busybox} sh -c "while true; do sleep 1d; done"'.format( - name=quote(service_name), busybox=quote(target.busybox) + name=quote(service_name), busybox=quote(target.busybox or '') ), as_root=True, ) @@ -326,15 +330,13 @@ def _request_delegation(target: LinuxTarget): @contextmanager -def _mount_v2_controllers(target: LinuxTarget): +def _mount_v2_controllers(target: LinuxTarget) -> Generator[str, None, None]: """ Mounts the V2 unified CGroup controller hierarchy. :param target: Interface to target device. - :type target: Target :yield: The path to the root of the mounted V2 controller hierarchy. - :rtype: str :raises TargetStableError: Occurs in the case where the root directory of the requested CGroup V2 Controller hierarchy is unable to be created up on the target system. @@ -344,7 +346,7 @@ def _mount_v2_controllers(target: LinuxTarget): try: target.execute( "{busybox} mount -t cgroup2 none {path}".format( - busybox=quote(target.busybox), path=quote(path) + busybox=quote(target.busybox or ''), path=quote(path) ), as_root=True, ) @@ -352,7 +354,7 @@ def _mount_v2_controllers(target: LinuxTarget): finally: target.execute( "{busybox} umount {path}".format( - busybox=quote(target.busybox), + busybox=quote(target.busybox or ''), path=quote(path), ), as_root=True, @@ -360,18 +362,15 @@ def _mount_v2_controllers(target: LinuxTarget): @contextmanager -def _mount_v1_controllers(target: LinuxTarget, controllers: Set[str]): +def _mount_v1_controllers(target: LinuxTarget, controllers: Set[str]) -> Generator[Dict[str, str], None, None]: """ Mounts the V1 split CGroup controller hierarchies. :param target: Interface to target device. 
- :type target: Target :param controllers: The names of the CGroup controllers required to be mounted. - :type controllers: Set[str] :yield: A dictionary mapping CGroup controller names to the paths that they're currently mounted at. - :rtype: Dict[str,str] :raises TargetStableError: Occurs in the case where the root directory of a requested CGroup V1 Controller hierarchy is unable to be created up on the target system. @@ -385,7 +384,7 @@ def _mount_controller(controller): try: target.execute( "{busybox} mount -t cgroup -o {controller} none {path}".format( - busybox=quote(target.busybox), + busybox=quote(target.busybox or ''), controller=quote(controller), path=quote(path), ), @@ -395,7 +394,7 @@ def _mount_controller(controller): finally: target.execute( "{busybox} umount {path}".format( - busybox=quote(target.busybox), + busybox=quote(target.busybox or ''), path=quote(path), ), as_root=True, @@ -409,17 +408,15 @@ def _mount_controller(controller): def _validate_requested_hierarchy( - requested_controllers: Set[str], available_controllers: Dict -): + requested_controllers: Set[str], available_controllers: Dict[str, Any] +) -> None: """ Validates that the requested hierarchy is valid using the controllers available on the target system. :param requested_controllers: A set of ``str``, representing the controllers that are requested to be used in the user defined hierarchy. - :type requested_controllers: Set[str] :param available_controllers: A dictionary where the primary keys represent the available CGroup controllers on the target system. - :type available_controllers: Dict :raises TargetStableError: Occurs in the case where the requested CGroup hierarchy is unable to be set up on the target system. 
@@ -428,7 +425,7 @@ def _validate_requested_hierarchy( # Will determine if there are any controllers present within the requested controllers # and not within the available controllers - diff = set(requested_controllers) - available_controllers.keys() + diff: Set[str] = set(requested_controllers) - available_controllers.keys() if diff: raise TargetStableError( @@ -441,25 +438,21 @@ class _CGroupBase(ABC): The abstract base class that all CGroup class types' subclass. :param name: The name assigned to the CGroup. Used to identify the CGroup and define the CGroup directory name. - :type name: str :param parent_path: The path to the parent CGroup this CGroup is a child of. - :type parent_path: str :param active_controllers: A dictionary of CGroup controller name keys to dictionary value mappings, where the secondary dictionary contains a mapping between a specific 'attribute' of the aforementioned controller and a value for which that controller interface file should be set to. - :type active_controllers: Dict[str, Dict[str, Union[str,int]]] :param target: Interface to target device. - :type target: Target """ def __init__( self, name: str, parent_path: str, - active_controllers: Dict[str, Dict[str, str]], + active_controllers: ControllerDict, target: LinuxTarget, ): self.name = name @@ -468,77 +461,71 @@ def __init__( self._parent_path = parent_path @property - def group_path(self): + def group_path(self) -> str: return self.target.path.join(self._parent_path, self.name) def _set_controller_attribute( self, controller: str, attribute: str, value: Union[int, str], verify=False - ): + ) -> None: """ Writes the specified ``value`` into the interface file specified by the ``controller`` and ``attribute`` parameters. In the case where no ``controller`` name is specified, the ``attribute`` argument is assumed to be the name of the interface file to write to. :param controller: The controller we want to select. 
- :type controller: str :param attribute: The specific attribute of the controller we want to alter. - :type attribute: str :param value: The value we want to write to the specified interface file. - :type value: str :param verify: Whether we want to verify that the value is indeed written to the interface file, defaults to ``False``. - :type verify: bool, optional """ str_value = str(value) # Some CGroup interface files don't have a controller name prefix, we accommodate that here. - interface_file = controller + "." + attribute if controller else attribute + interface_file: str = controller + "." + attribute if controller else attribute - full_path = self.target.path.join(self.group_path, interface_file) + full_path: str = self.target.path.join(self.group_path, interface_file) self.target.write_value(full_path, str_value, verify=verify) - def _create_directory(self, path: str): + def _create_directory(self, path: str) -> None: """ Creates a new directory at the given path, creating the parent directories if required. If the directory already exists, no exception is thrown. :param path: Path to directory to be created. - :type path: str """ self.target.makedirs(path, as_root=True) - def _delete_directory(self, path: str): + def _delete_directory(self, path: str) -> None: """ Removes the directory at the given path. :param path: Path to the directory to be removed. - :type path: str """ # In this context we can't use the target.remove method since that # tries to delete the interface/controller files as well which isn't needed nor permitted. 
self.target.execute( "{busybox} rmdir -- {path}".format( - busybox=quote(self.target.busybox), path=quote(path) + busybox=quote(self.target.busybox or ''), path=quote(path) ), as_root=True, ) - def _add_process(self, pid: Union[str, int]): + def _add_process(self, pid: Union[str, int]) -> Optional[TargetStableError]: """ Adds the process associated with the ``pid`` to the CGroup, only if the process is not already a member of the CGroup. :param pid: The PID of the process to be added to the CGroup. - :type pid: Union[str,int] """ if not self.target.file_exists(filepath="/proc/{pid}/status".format(pid=pid)): + # FIXME - is this return of the error intentional or was it meant to be raised return TargetStableError( "The Process ID: {pid} does not exists.".format(pid=pid) ) @@ -558,26 +545,25 @@ def _add_process(self, pid: Union[str, int]): else: if str(pid) not in member_processes: self._set_controller_attribute("cgroup", "procs", pid) + return None - def _get_pid_from_tid(self, tid: int): + def _get_pid_from_tid(self, tid: int) -> int: """ Retrieves the ``pid`` (Process ID) that the ``tid`` (Thread ID) is a part of. :param tid: The Thread ID of the thread to be added to the CGroup. - :type tid: int :return: The ``pid`` (Process ID) associated with the ``tid`` (Thread ID). - :rtype: int """ - status = _read_lines( + status: List[str] = _read_lines( target=self.target, path="/proc/{tid}/status".format(tid=tid) ) for line in status: # the Tgid entry contains the thread group ID, which is the PID of # the process this thread belongs to. 
- match = re.match(r"\s*Tgid:\s*(\d+)\s*", line) + match: Optional[Match] = re.match(r"\s*Tgid:\s*(\d+)\s*", line) if match: - pid = match.group(1) + pid: str = match.group(1) break else: raise TargetStableError( @@ -587,7 +573,7 @@ def _get_pid_from_tid(self, tid: int): return int(pid) @abstractmethod - def _add_thread(self, tid: int, threaded_domain): + def _add_thread(self, tid: int, threaded_domain: Union['ResponseTree', '_TreeBase']): """ Ensures all sub-classes have the ability to add threads to their CGroups where their differences dont allow for a common approach. @@ -595,7 +581,7 @@ def _add_thread(self, tid: int, threaded_domain): pass @abstractmethod - def _init_cgroup(self): + def _init_cgroup(self) -> None: """ Ensures all sub-classes are able to initialise their respective CGroup directories as per defined by their user configurations. @@ -622,33 +608,27 @@ class _CGroupV2(_CGroupBase): A Class representing a CGroup directory within a CGroup V2 hierarchy. :param name: The name assigned to the CGroup. Used to identify the CGroup and define the CGroup folder name. - :type name: str :param parent_path: The path to the parent CGroup this CGroup is a child of. - :type parent_path: str :param active_controllers: A dictionary of controller name keys to dictionary value mappings, where the secondary dictionary contains a mapping between a specific 'attribute' of the aforementioned controller and a value for which that controller interface file should be set to. - :type active_controllers: Dict[str, Dict[str, Union[str,int]]] :param subtree_controllers: The controllers that should be delegated to the subtree. - :type subtree_controllers: Set[str] :param is_threaded: Whether the CGroup type is threaded, enables thread level granularity for the CGroup directory and its subtree. - :type is_threaded: bool :param target: Interface to target device. 
- :type target: Target """ def __init__( self, name: str, parent_path: str, - active_controllers: Dict[str, Dict[str, str]], - subtree_controllers: set, + active_controllers: ControllerDict, + subtree_controllers: set[str], is_threaded: bool, target: LinuxTarget, ): @@ -686,7 +666,7 @@ def __enter__(self): def __exit__(self, *exc): self._delete_directory(path=self.group_path) - def _init_cgroup(self): + def _init_cgroup(self) -> None: """ Performs the required steps in order to initialize the CGroup to the user specified configuration: @@ -727,7 +707,7 @@ def _init_cgroup(self): value="+{cont}".format(cont=controller), ) - def _add_thread(self, tid: int, threaded_domain): + def _add_thread(self, tid: int, threaded_domain: Union['ResponseTree', '_TreeBase']) -> None: """ Attempts to add the thread associated with ``tid`` to the CGroup. Due to the requirements imposed by the kernel regarding thread management within a V2 CGroup hierarchy, @@ -737,17 +717,16 @@ def _add_thread(self, tid: int, threaded_domain): across the entire subtree. :param tid: The TID (Thread ID) of the thread to be added to the CGroup. - :type tid: int :param threaded_domain: The :class:`ResponseTree` object representing the threaded domain of the threaded CGroup subtree. The process will be added to all the CGroups that the :class:`ResponseTree` represent. - :type threaded_domain: :class:`ResponseTree` + """ - pid_of_tid = self._get_pid_from_tid(tid=tid) + pid_of_tid: int = self._get_pid_from_tid(tid=tid) - for low_level in threaded_domain.low_levels.values(): + for low_level in cast(ResponseTree, threaded_domain).low_levels.values(): low_level._add_process(pid_of_tid) self._set_controller_attribute( @@ -762,19 +741,16 @@ class _CGroupV2Root(_CGroupV2): CGroup hierarchy. :param mount_point: The path on which the root of the CGroup V2 hierarchy is mounted on. - :type mount_point: str :param subtree_controllers: The controllers that should be delegated to the subtree. 
- :type subtree_controllers: Set[str] :param target: Interface to target device. - :type target: Target """ @classmethod def _v2_controller_translation( - cls, controllers: Dict[str, Dict[str, Union[str, int]]] - ): + cls, controllers: ControllerDict + ) -> ControllerDict: """ Given the new controller names within V2, rename the controllers to provide CGroupV2 compatibility. At this point in time, the ``blkio`` controller has been renamed to ``io`` in V2, while the V2 ``cpu`` controller @@ -783,7 +759,6 @@ def _v2_controller_translation( :param controllers: A dictionary of controller name keys to dictionary value mappings, where the secondary dictionary contains a mapping between the ``version`` and `mount_point`` keys and their respectively obtained values. - :rtype: Dict[str, Dict[str,Union[str,int]]] :raises TargetStableError: In the case where the the ``cpu`` and ``cpuacct`` CGroup controllers are in use under different CGroup version hierarchies. @@ -791,10 +766,9 @@ def _v2_controller_translation( :raises TargetStableError: In the case where either ``cpu`` / ``cpuacct`` controller is not enabled on the target device. :return: The amended ``controllers`` dictionary with the updated names. - :rtype: Dict[str, Dict[str, Union[str,int]]] """ - translation = {} + translation: ControllerDict = {} if "blkio" in controllers: translation["io"] = controllers["blkio"] @@ -824,41 +798,39 @@ def _v2_controller_translation( } @classmethod - def _get_delegated_sub_path(cls, delegated_pid: int, target: LinuxTarget): + def _get_delegated_sub_path(cls, delegated_pid: int, target: LinuxTarget) -> Optional[str]: """ Returns the relative sub-path the delegated root of the V2 hierarchy is mounted on, via the parsing of the /proc//cgroup file of the delegated process associated with ``delegated_pid``. :param delegated_pid: The Main PID of the transient service unit we requested delegation for. - :type delegated_pid: int :param target: Interface to target device. 
- :type target: Target :return: The sub-path to the delegate root of the V2 CGroup hierarchy. - :rtype: str """ - relative_delegated_mount_paths = _read_lines( + relative_delegated_mount_paths: List[str] = _read_lines( target=target, path="/proc/{pid}/cgroup".format(pid=delegated_pid) ) # Following Regex matches the line that contains the relative sub path. - REL_PATH_REGEX = re.compile(r"0::\/(?P.+)") + REL_PATH_REGEX: Pattern[str] = re.compile(r"0::\/(?P.+)") for mount_path in relative_delegated_mount_paths: - m = REL_PATH_REGEX.match(mount_path) + m: Optional[Match[str]] = REL_PATH_REGEX.match(mount_path) if m: return m.group("path") else: raise TargetStableError( "A V2 CGroup hierarchy was not delegated by systemd." ) + return None @classmethod def _get_available_controllers( - cls, controllers: Dict[str, Dict[str, Union[str, int]]] - ): + cls, controllers: ControllerDict + ) -> ControllerDict: """ Returns the CGroup controllers that are currently not in use on the target device, which can be taken control over and used in a manually mounted V2 hierarchy. @@ -867,17 +839,15 @@ def _get_available_controllers( :param controllers: A dictionary of CGroup controller name keys to dictionary value mappings, where the secondary dictionary contains a mapping between various CGroup controller configuration keys and their respectively obtained values for the respective CGroup controllers. - :rtype: Dict[str, Dict[str,Union[str,int]]] :raises TargetStableError: Occurs in the case where a V2 hierarchy is already mounted on the target device. We want to bail out in this case. :return: The ``controllers`` Dict filtered to just those controllers which are free/un-used. - :rtype: Dict[str, Dict[str, Union[str,int]]] """ # Filters the controllers dict to entries where the version is == 2. 
- mounted_v2_controllers = { + mounted_v2_controllers: Set[str] = { controller for controller, configuration in controllers.items() if (configuration.get("version") == 2) @@ -896,8 +866,8 @@ def _get_available_controllers( @classmethod def _path_to_delegated_root( - cls, controllers: Dict[str, Dict[str, Union[int, str]]], sub_path: str - ): + cls, controllers: ControllerDict, sub_path: str + ) -> str: """ Return the full path to the delegated root. This occurs in 2 stages: @@ -911,19 +881,16 @@ def _path_to_delegated_root( :param controllers: A Dictionary of currently mounted controller name keys to Dictionary value mappings, where the secondary dictionary contains a mapping between various CGroup controller configuration keys and their respectively obtained values for the respective CGroup controllers. - :type controllers: Dict[str, Dict[str, Union[str,int]]] :param sub_path: The relative subpath to the delegated root hierarchy. - :type sub_path: str :raises TargetStableError: Occurs in the case where no V2 controllers are active on the target. :return: A full path to the delegated root of the V2 CGroup hierarchy. - :rtype: str """ # Filter out non v2 controller mounts and append the "mount_point" to a set - v2_mount_point = { + v2_mount_point: Set[Union[str, int]] = { configuration["mount_point"] for configuration in controllers.values() if configuration.get("version") == 2 @@ -934,36 +901,32 @@ def _path_to_delegated_root( ) else: # Since there can only be a single V2 hierarchy (ignoring bind mounts), this should be totally legal. 
- mount_path_to_unified_hierarchy = v2_mount_point.pop() - return str(os.path.join(mount_path_to_unified_hierarchy, sub_path)) + mount_path_to_unified_hierarchy: Union[str, int] = v2_mount_point.pop() + return str(os.path.join(cast(str, mount_path_to_unified_hierarchy), sub_path)) @classmethod @contextmanager def _systemd_offline_mount( cls, target: LinuxTarget, - all_controllers: Dict[str, Dict[str, Union[str, int]]], + all_controllers: ControllerDict, requested_controllers: Set[str], - ): + ) -> Generator[str, None, None]: """ Manually mounts the V2 hierarchy on the target device. Occurs in the absence of systemd. :param target: Interface to target device. - :type target: Target :param all_controllers: A Dictionary of currently mounted controller name keys to Dictionary value mappings, where the secondary dictionary contains a mapping between various CGroup controller configuration keys and their respectively obtained values for the respective CGroup controllers. - :type controllers: Dict[str, Dict[str, Union[str,int]]] :param requested_controllers: The set of controllers required to mount the requested hierarchy. - :type requested_controllers: Set[str] :yield: The path to the root mount point of the unified V2 hierarchy. - :rtype: str """ - unused_controllers = _CGroupV2Root._get_available_controllers( + unused_controllers: ControllerDict = _CGroupV2Root._get_available_controllers( controllers=all_controllers ) _validate_requested_hierarchy( @@ -979,36 +942,32 @@ def _systemd_offline_mount( def _systemd_online_setup( cls, target: LinuxTarget, - all_controllers: Dict[str, Dict[str, int]], + all_controllers: ControllerDict, requested_controllers: Set[str], - ): + ) -> Generator[str, None, None]: """ Sets up the required V2 hierarchy on the target device. Occurs in the presence of systemd. :param target: Interface to target device. 
- :type target: Target :param all_controllers: A Dictionary of currently mounted CGroup controller name keys to dictionary value mappings, where the secondary dictionary contains a mapping between various CGroup controller configuration keys and their respectively obtained values for the respective CGroup controllers. - :type all_controllers: Dict[str, Dict[str, Union[str,int]]] :param requested_controllers: The set of controllers required to mount the requested hierarchy. - :type requested_controllers: Set[str] :yield: The path to the root of the delegated V2 CGroup hierarchy. - :rtype: str """ with _request_delegation(target=target) as main_pid: - delegated_sub_path = _CGroupV2Root._get_delegated_sub_path( + delegated_sub_path: Optional[str] = _CGroupV2Root._get_delegated_sub_path( delegated_pid=main_pid, target=target ) - delegated_path = _CGroupV2Root._path_to_delegated_root( + delegated_path: str = _CGroupV2Root._path_to_delegated_root( controllers=all_controllers, - sub_path=delegated_sub_path, + sub_path=cast(str, delegated_sub_path), ) - delegated_controllers_path = "{path}/cgroup.controllers".format( + delegated_controllers_path: str = "{path}/cgroup.controllers".format( path=delegated_path ) @@ -1018,7 +977,7 @@ def _systemd_online_setup( # by _read_file and splitting said element (str) using the white space character # as the delimiter. # (The _validate_requested_hierarchy requires the available_controllers argument to be a dict, necessitating this dict structure.) 
- delegated_controllers = { + delegated_controllers: Dict[str, None] = { controller: None for controller in _read_lines( target=target, path=delegated_controllers_path @@ -1033,28 +992,25 @@ def _systemd_online_setup( @classmethod @contextmanager - def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]): + def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]) -> Generator[str, None, None]: """ Mounts/Sets-up a V2 hierarchy on the target device, covering contexts where systemd is both present and absent. :param target: Interface to target device. - :type target: Target :param requested_controllers: The set of controllers required to mount the requested hierarchy. - :type requested_controllers: Set[str] :yield: A path to the root of the V2 hierarchy that has been mounted/delegated for the user. - :rtype: str """ - systemd_online = _is_systemd_online(target=target) - controllers = _CGroupV2Root._v2_controller_translation( + systemd_online: bool = _is_systemd_online(target=target) + controllers: ControllerDict = _CGroupV2Root._v2_controller_translation( _get_cgroup_controllers(target=target) ) if systemd_online: - cm = _CGroupV2Root._systemd_online_setup( + cm: _GeneratorContextManager[str] = _CGroupV2Root._systemd_online_setup( target=target, all_controllers=controllers, requested_controllers=requested_controllers, @@ -1074,7 +1030,7 @@ def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]) def __init__( self, mount_point: str, - subtree_controllers: set, + subtree_controllers: set[str], target: LinuxTarget, ): @@ -1088,7 +1044,7 @@ def __init__( is_threaded=False, target=target, ) - self.target = target + self.target: LinuxTarget = target def __enter__(self): """ @@ -1113,7 +1069,7 @@ def __enter__(self): def __exit__(self, *exc): pass - def _init_root_cgroup(self): + def _init_root_cgroup(self) -> None: """ Performs the required actions in order to initialise a Root V2 CGroup. 
In the case where systemd is active, there is a required need to create a leaf CGroup from the Root, where the PIDs @@ -1124,11 +1080,11 @@ def _init_root_cgroup(self): if _is_systemd_online(target=self.target): # Create the leaf CGroup directory - group_name = "devlib-" + str(uuid4().hex) - full_path = self.target.path.join(self.group_path, group_name) + group_name: str = "devlib-" + str(uuid4().hex) + full_path: str = self.target.path.join(self.group_path, group_name) self._create_directory(full_path) - delegated_pids = _read_lines( + delegated_pids: List[str] = _read_lines( target=self.target, path="{path}/cgroup.procs".format(path=self.group_path), ) @@ -1155,19 +1111,14 @@ class _CGroupV1(_CGroupBase): A Class representing a CGroup folder within a CGroup V1 hierarchy. :param name: The name assigned to the CGroup. Used to identify the CGroup and define the CGroup folder name. - :type name: str :param parent_path: The path to the parent CGroup this CGroup is a child of. - :type parent_path: str :param active_controllers: A dictionary of controller name keys to dictionary value mappings, where the secondary dictionary contains a mapping between a specific 'attribute' of the aforementioned controller and a value for which that controller interface should be set to. - :type active_controllers: Dict[str, Dict[str, Union[str,int]]] - :param target: Interface to target device. - :type target: Target """ def __enter__(self): @@ -1179,7 +1130,6 @@ def __enter__(self): :raises TargetStableError: If an exception occurs within the :meth:`_init_cgroup` method call. :return: An object reference to itself. 
- :rtype: :class:`_CGroupV1` """ self._create_directory(self.group_path) @@ -1209,7 +1159,7 @@ def _init_cgroup(self): controller=controller, attribute=attr, value=val, verify=True ) - def _add_thread(self, tid: int, threaded_domain): + def _add_thread(self, tid: int, threaded_domain: Union['ResponseTree', '_TreeBase']) -> None: """ Adds the thread associated with ``tid`` to the CGroup. While thread level management suffers from no restrictions within a V1 hierarchy, @@ -1220,17 +1170,15 @@ def _add_thread(self, tid: int, threaded_domain): granularity across the entire of the threaded subtree. :param tid: The TID of the thread to be added to the CGroup - :type tid: int :param threaded_domain: The :class:`ResponseTree` object representing the threaded domain of the threaded CGroup subtree. The process will be added to all the CGroups that the :class:`ResponseTree` represents. - :type threaded_domain: :class:`ResponseTree` """ - pid_of_tid = self._get_pid_from_tid(tid=tid) + pid_of_tid: int = self._get_pid_from_tid(tid=tid) - for low_level in threaded_domain.low_levels.values(): + for low_level in cast(ResponseTree, threaded_domain).low_levels.values(): low_level._add_process(pid_of_tid) self._set_controller_attribute("", "tasks", tid) @@ -1243,19 +1191,17 @@ class _CGroupV1Root(_CGroupV1): CGroup hierarchy. :param mount_point: The path to which the root of the CGroup V1 controller hierarchy is mounted on. - :type mount_point: str :param target: Interface to target device. - :type target: Target """ @classmethod def _get_delegated_paths( cls, - controllers: Dict[str, Dict[str, Union[str, int]]], + controllers: ControllerDict, delegated_pid: int, target: LinuxTarget, - ): + ) -> Dict[str, str]: """ Returns the relative sub-paths the delegated roots of the V1 hierarchies, via the parsing of the /proc//cgroup file of the delegated PID. 
@@ -1263,21 +1209,17 @@ def _get_delegated_paths(

        :param controllers: A dictionary of currently mounted CGroup controller name keys to dictionary value mappings, where the secondary
            dictionary contains a mapping between various CGroup controller configuration keys and their respectively
            obtained values for the respective CGroup controllers.
-        :type controllers: Dict[str, Dict[str, Union[str,int]]]

        :param delegated_pid: The Main PID of the transient service unit we request delegation for.
-        :type delegated_pid: int

        :param target: Interface to target device.
-        :type target: Target

        :raises TargetStableError: Occurs in the case where no V1 controllers have been delegated.

        :return: A dictionary mapping CGroup controllers to their respective delegated root paths.
-        :rtype: Dict[str, str]
        """

-        delegated_mount_paths = _read_lines(
+        delegated_mount_paths: List[str] = _read_lines(
            target=target, path="/proc/{pid}/cgroup".format(pid=delegated_pid)
        )
@@ -1288,22 +1230,22 @@ def _get_delegated_paths(
        #
        # The regex is structured to only match V1 controller hierarchies.

-        REL_PATH_REGEX = re.compile(
+        REL_PATH_REGEX: Pattern[str] = re.compile(
            r"\d+:(?P<controllers>.+):\/(?P<path_to_delegated_service_root>.*)"
        )

-        delegated_controllers = {}
+        delegated_controllers: Dict[str, str] = {}

        for mount_path in delegated_mount_paths:
-            regex_match = REL_PATH_REGEX.match(mount_path)
+            regex_match: Optional[Match[str]] = REL_PATH_REGEX.match(mount_path)
            if regex_match:
-                con = regex_match.group("controllers")
-                path = regex_match.group("path_to_delegated_service_root")
+                con: str = regex_match.group("controllers")
+                path: str = regex_match.group("path_to_delegated_service_root")
                # Multiple v1 controllers can be co-mounted on a single folder hierarchy.
- co_mounted_controllers = con.strip().split(",") + co_mounted_controllers: List[str] = con.strip().split(",") for controller in co_mounted_controllers: try: - configuration = controllers[controller] + configuration: Dict[str, Union[str, int]] = controllers[controller] except KeyError: pass else: @@ -1323,28 +1265,24 @@ def _get_delegated_paths( def _systemd_offline_mount( cls, requested_controllers: Set[str], - all_controllers: Dict[str, Dict[str, Union[str, int]]], + all_controllers: ControllerDict, target: LinuxTarget, ): """ Manually mounts the V1 split hierarchy on the target device. Occurs in the absence of systemd. :param requested_controllers: The set of controllers required to mount the requested hierarchy. - :type requested_controllers: Set[str] :param all_controllers: A Dictionary of currently mounted controller name keys to Dictionary value mappings, where the secondary dictionary contains a mapping between various CGroup controller configuration keys and their respectively obtained values for the respective CGroup controllers. - :type all_controllers: Dict[str, Dict[str, Union[str,int]]] :param target: Interface to target device. - :type target: Target :yield: A dictionary mapping CGroup controller names to their respective mount points. 
- :rtype: Dict[str,str] """ - available_controllers = _CGroupV1Root._get_available_v1_controllers( + available_controllers: ControllerDict = _CGroupV1Root._get_available_v1_controllers( controllers=all_controllers ) _validate_requested_hierarchy( @@ -1358,10 +1296,13 @@ def _systemd_offline_mount( @classmethod def _get_available_v1_controllers( - cls, controllers: Dict[str, Dict[str, Union[int, str]]] - ): + cls, controllers: ControllerDict + ) -> ControllerDict: + """ + helper function to get the available v1 controllers + """ - unused_controllers = { + unused_controllers: ControllerDict = { controller: configuration for controller, configuration in controllers.items() if configuration.get("version") is None @@ -1378,28 +1319,24 @@ def _systemd_online_setup( cls, target: LinuxTarget, requested_controllers: Set[str], - all_controllers: Dict[str, Dict[str, str]], - ): + all_controllers: ControllerDict, + ) -> Generator[Dict[str, str], None, None]: """ Sets up the required V1 hierarchy on the target device. Occurs in the presence of systemd. :param target: Interface to target device. - :type target: Target :param requested_controllers: The set of controllers required to mount the requested hierarchy. - :type requested_controllers: Set[str] :param all_controllers: A Dictionary of currently mounted controller name keys to dictionary value mappings, where the secondary dictionary contains a mapping between various CGroup controller configuration keys and their respectively obtained values for the respective CGroup controllers. - :type all_controllers: Dict[str, Dict[str, Union[str,int]]] :yield: A Dict[str, str] consisting of controller name keys mapped to their respective mount points. 
- :rtype: Dict[str, str] """ with _request_delegation(target) as pid: - delegated_controllers = _CGroupV1Root._get_delegated_paths( + delegated_controllers: Dict[str, str] = _CGroupV1Root._get_delegated_paths( controllers=all_controllers, delegated_pid=pid, target=target, @@ -1413,7 +1350,7 @@ def _systemd_online_setup( @classmethod @contextmanager - def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]): + def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]) -> Generator[Dict[str, str], None, None]: """ A context manager which Mounts/Sets-up a V1 split hierarchy on the target device, covering contexts where systemd is both present and absent. This context manager Mounts/Sets-up a split V1 hierarchy (if possible) @@ -1421,26 +1358,23 @@ def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]) the target device to the state before the mount/set-up occurred. :param target: Interface to target device. - :type target: Target :param requested_controllers: The set of controllers required to mount the requested hierarchy. - :type requested_controllers: Set[str] :yield: A dictionary mapping controller name to the paths where the controllers are mounted on, used to build the user requested V1 hierarchy. 
- :rtype: dict[str,str] """ - systemd_online = _is_systemd_online(target=target) - controllers = _get_cgroup_controllers(target=target) + systemd_online: bool = _is_systemd_online(target=target) + controllers: ControllerDict = _get_cgroup_controllers(target=target) if systemd_online: - cm = _CGroupV1Root._systemd_online_setup( + cm: _GeneratorContextManager[Dict[str, str]] = _CGroupV1Root._systemd_online_setup( target=target, requested_controllers=requested_controllers, all_controllers=controllers, ) - with cm as controllers: - yield controllers + with cm as controllers_temp: + yield controllers_temp else: cm = _CGroupV1Root._systemd_offline_mount( @@ -1448,8 +1382,8 @@ def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]) requested_controllers=requested_controllers, all_controllers=controllers, ) - with cm as controllers: - yield controllers + with cm as controllers_temp: + yield controllers_temp def __init__(self, mount_point: str, target: LinuxTarget): @@ -1482,10 +1416,8 @@ class _TreeBase(ABC): The abstract base class that all tree class types' subclass. :param name: The name assigned to the tree node. - :type name: str :param is_threaded: Whether the node is threaded or not. - :type is_threaded: bool """ def __init__(self, name: str, is_threaded: bool): @@ -1495,14 +1427,14 @@ def __init__(self, name: str, is_threaded: bool): # Propagates Threaded Property to # sub-tree. - def make_threaded(grp): + def make_threaded(grp: '_TreeBase'): grp.is_threaded = True for child in grp._children_list: make_threaded(child) # Propagates the Threaded domain # to sub-tree. 
- def set_domain(grp): + def set_domain(grp: '_TreeBase'): grp.threaded_domain = domain for child in grp._children_list: set_domain(child) @@ -1510,14 +1442,18 @@ def set_domain(grp): if is_threaded: make_threaded(self) else: - domain = self - if any([child.is_threaded for child in self._children_list]): - for child in self._children_list: - make_threaded(child) - set_domain(child) + domain: '_TreeBase' = self + if self._children_list: + if any([child.is_threaded for child in self._children_list]): + for child in self._children_list: + make_threaded(child) + set_domain(child) @property - def is_threaded_domain(self): + def is_threaded_domain(self) -> bool: + """ + check if the is_threaded property is set in the domain + """ return ( True if any([child.is_threaded for child in self._children_list]) @@ -1527,7 +1463,10 @@ def is_threaded_domain(self): @property @memoized - def group_type(self): + def group_type(self) -> str: + """ + get the type of the group + """ if self.is_threaded_domain: return "threaded domain" elif self.is_threaded: @@ -1540,10 +1479,8 @@ def __str__(self, level=0): Returns a string representation of the tree hierarchy, used for visualization and debugging. :param level: The current depth of the tree, defaults to 0. - :type level: int, optional :return: String formatted output, displaying the hierarchical structure of the tree. - :rtype: str """ TAB = "\t" @@ -1563,7 +1500,7 @@ def __str__(self, level=0): @property @abstractmethod - def _node_information(self): + def _node_information(self) -> str: """ Returns a formatted string displaying the information the :class:`_TreeBase` object represents. """ @@ -1571,7 +1508,7 @@ def _node_information(self): @property @abstractmethod - def _children_list(self): + def _children_list(self) -> List['_TreeBase']: """ Returns List[:class:`_TreeBase`]. 
""" @@ -1585,36 +1522,32 @@ class RequestTree(_TreeBase): required by ensuring V2 semantic equivalence is maintained within the context of setting up a V1 hierarchy. :param name: Name assigned to the user defined :class:`RequestTree` object. - :type name: str :param children: A list of :class:`RequestTree` objects representing the children the object is a hierarchical parent to, defaults to ``None``. - :type children: List[:class:`RequestTree`], optional :param controllers: A Dictionary of controller name keys to dictionary value mappings, where the secondary dictionary contains a mapping between controller specific attributes and their respective to be assigned values, , defaults to ``None``. - :type controllers: Dict[str, Dict[str, Union[str,int]]], optional :param threaded: defines whether the object will represent a CGroup capable of managing threads, defaults to ``False``. - :type threaded: bool, optional """ def __init__( self, name: str, - children: Union[list, None] = None, - controllers: Union[Dict[str, Dict[str, Any]], None] = None, - threaded=False, + children: Union[list['RequestTree'], None] = None, + controllers: Optional[ControllerDict] = None, + threaded: bool = False, ): self.children = children or [] self.controllers = controllers or {} super().__init__(name=name, is_threaded=threaded) @property - def _node_information(self): + def _node_information(self) -> str: # Returns Requests Tree Node Information. - active_controllers = [ + active_controllers: List[str] = [ "({controller}) {config}".format( controller=controller, config=configuration ) @@ -1627,8 +1560,10 @@ def _node_information(self): @property @memoized - def _all_controllers(self): - # Returns a set of all the controllers that are active in that subtree, including its own. + def _all_controllers(self) -> Set[str]: + """ + Returns a set of all the controllers that are active in that subtree, including its own. 
+ """ return set( itertools.chain( self.controllers.keys(), @@ -1639,8 +1574,10 @@ def _all_controllers(self): ) @property - def _subtree_controllers(self): - # Returns a set of all the controllers that are active in that subtree, excluding its own. + def _subtree_controllers(self) -> Set[str]: + """ + Returns a set of all the controllers that are active in that subtree, excluding its own. + """ return set( itertools.chain.from_iterable( map(lambda child: child._all_controllers, self.children) @@ -1648,11 +1585,11 @@ def _subtree_controllers(self): ) @property - def _children_list(self): + def _children_list(self) -> List['_TreeBase']: return list(self.children) @contextmanager - def setup_hierarchy(self, version: int, target: LinuxTarget): + def setup_hierarchy(self, version: int, target: LinuxTarget) -> Generator['ResponseTree', None, None]: """ A context manager which processes the user defined hierarchy and sets-up said hierarchy on the ``target`` device. Uses an internal exit stack to the handle the entering and safe exiting of the lower level @@ -1662,10 +1599,8 @@ def setup_hierarchy(self, version: int, target: LinuxTarget): which the user will interact with and can inspect. :param version: The version of the CGroup hierarchy to be set up on the Target device. - :type version: int :param target: Interface to target device. - :type target: Target :raises TargetStableError: Occurs when the version argument is neither ``1`` or ``2``; the only two versions of CGroups currently available. @@ -1674,15 +1609,16 @@ def setup_hierarchy(self, version: int, target: LinuxTarget): """ with ExitStack() as exit_stack: + make_groups: Callable if version == 1: # Returns a {controller_name: controller_mount_point} dict - controller_paths = exit_stack.enter_context( + controller_paths: Dict[str, str] = exit_stack.enter_context( _CGroupV1Root._mount_filesystem( target=target, requested_controllers=self._all_controllers ) ) # Mounts the Roots Controller Parents. 
- root_parents = { + root_parents: Union[Dict[str, _CGroupV1], _CGroupV2] = { controller: _CGroupV1Root( mount_point=mount_path, target=target, @@ -1691,7 +1627,7 @@ def setup_hierarchy(self, version: int, target: LinuxTarget): if controller in self._all_controllers } - def make_groups(request: RequestTree, parents: Dict[str, _CGroupBase]): + def make_groups_v1(request: RequestTree, parents: Dict[str, _CGroupBase]): """ Defines and instantiates the low-level :class:`_CGroupV1` objects as per defined by the configuration of the ``request`` :class:`RequestTree` object. @@ -1700,10 +1636,8 @@ def make_groups(request: RequestTree, parents: Dict[str, _CGroupBase]): created 'under' the suitable parent CGroup directory. :param request: The :class:`RequestTree` object that'll define the required :class:`_CGroupV1` objects it represents. - :type request: :class:`RequestTree` :param parents: The Dictionary mapping that maps CGroup controller names to their leaf CGroup directory. - :type parents: Dict[str, :class:`_CGroupBase`] :return: A tuple ``(request_defined_cgroups, all_cgroups, parents)`` where the first element defines the dictionary mapping the controller names to the :class:`_CGroupV1` objects created directly @@ -1715,10 +1649,9 @@ def make_groups(request: RequestTree, parents: Dict[str, _CGroupBase]): the low-level :class:`_CGroupV1` objects a particular :class:`RequestTree` instance indirectly defines given its parents and the :class:`_CGroupV1` objects it passes to it children as potential suitable parents are the same. - :rtype: tuple(Dict[str,:class:`_CGroupV1`], Dict[str,:class:`_CGroupV1`], Dict[str,:class:`_CGroupV1`]) """ - request_defined_cgroups = { + request_defined_cgroups: Dict[str, _CGroupV1] = { controller: _CGroupV1( name=request.name, parent_path=parents[controller].group_path, @@ -1730,9 +1663,11 @@ def make_groups(request: RequestTree, parents: Dict[str, _CGroupBase]): # Parent dict updated to include the newly created leaf CGroups. 
parents = {**parents, **request_defined_cgroups} - all_cgroups = parents + all_cgroups: Dict[str, _CGroupBase] = parents return (request_defined_cgroups, all_cgroups, parents) + make_groups = make_groups_v1 + elif version == 2: # Returns a string representing the root of the V2 hierarchy @@ -1752,7 +1687,7 @@ def make_groups(request: RequestTree, parents: Dict[str, _CGroupBase]): # root CGroup setup defined within the __enter__ method. exit_stack.enter_context(root_parents) - def make_groups(request: RequestTree, parent: _CGroupV2): + def make_groups_v2(request: RequestTree, parent: _CGroupV2): """ Defines and instantiates the low-level :class:`_CGroupV2` object as per defined by the configuration of the ``request`` :class:`RequestTree` object. The parents of said :class:`_CGroupV2` object @@ -1760,10 +1695,8 @@ def make_groups(request: RequestTree, parent: _CGroupV2): :class:`_CGroupV2` object is created 'under' the suitable parent CGroup directory. :param request: The :class:`RequestTree` object that'll define the required :class:`_CGroupV2` object. - :type request: :class:`RequestTree` :param parents: The CGroup that'll be the parent of the :class:`_CGroupV2` object being defined. - :type parents: :class:`_CGroupV2` :return: A tuple ``(controllers_to_cgroup, controllers_to_cgroup, parent)`` where the first and second elements define a dictionary mapping of controller names as per defined by the :class:`RequestTree` object @@ -1773,7 +1706,6 @@ def make_groups(request: RequestTree, parent: _CGroupV2): hierarchical parent of the subsequent V2 CGroups to be created. Duplication is required in this case since both the paths the user defined V2 controllers are enabled at and the actual paths of the low-level implementation are the same as per the structure of the unified V2 hierarchy. 
- :rtype: tuple(Dict[str,:class:`_CGroupV2`],Dict[str,:class:`_CGroupV2`],:class:`_CGroupV2`) """ request_group = _CGroupV2( @@ -1787,13 +1719,15 @@ def make_groups(request: RequestTree, parent: _CGroupV2): # Creates a mapping between the enabled controllers within this CGroup to the low-level # _CGroupV2 object - controllers_to_cgroup = dict.fromkeys( + controllers_to_cgroup: Dict[str, _CGroupV2] = dict.fromkeys( request.controllers, request_group ) # Creating 'parent' variable for readability’s sake. parent = request_group return (controllers_to_cgroup, controllers_to_cgroup, parent) + make_groups = make_groups_v2 + else: raise TargetStableError( "A {version} version hierarchy cannot be mounted. Ensure requested hierarchy version is 1 or 2.".format( @@ -1802,9 +1736,9 @@ def make_groups(request: RequestTree, parent: _CGroupV2): ) # Create the Response Tree from the Request Tree. - response = self._create_response(root_parents, make_groups=make_groups) + response: 'ResponseTree' = self._create_response(root_parents, make_groups=make_groups) # Returns a list of all the Low-level _CGroupBase objects the response object represents in the right order - groups = response._all_nodes + groups: List[_CGroupBase] = response._all_nodes # Remove duplicates while preserving order. groups = sorted(set(groups), key=groups.index) # Enter the context for each object @@ -1813,7 +1747,7 @@ def make_groups(request: RequestTree, parent: _CGroupV2): yield response - def _create_response(self, low_level_parent, make_groups): + def _create_response(self, low_level_parent: Union[Dict[str, _CGroupV1], _CGroupV2], make_groups) -> 'ResponseTree': """ Creates the :class:`ResponseTree` object tree, using the appropriately defined :meth:`make_group` callable (defined as a local function internally within :meth:`setup_hierarchy`) alongside the ``low_level_parent`` object to create the low-level CGroups a particular :class:`RequestTree` object represents. 
@@ -1822,14 +1756,11 @@ def _create_response(self, low_level_parent, make_groups): :param low_level_parent: The parent/s to the CGroups to be created. In the context of setting up a V1 hierarchy, this will be a dictionary mapping controller names to :class:`_CGroupV1` objects; while in the case of V2, it'll be a solitary :class:`_CGroupV2` object. - :type low_level_parent: Dict[str,:class:`_CGroupV1`] | :class:`_CGroupV2` :param make_groups: The callable function definition used to create the low-level CGroup required. This callable is defined appropriately depending on the CGroup hierarchy version we require to set-up/mount. - :type make_groups: callable :return: The root of the :class:`ResponseTree` object tree. - :rtype: :class:`ResponseTree` """ user_visible_low_level_groups, low_level_groups, low_level_parent = make_groups( @@ -1862,24 +1793,19 @@ class ResponseTree(_TreeBase, collections.abc.Mapping): each :class:`ResponseTree` object represents and abstracts the low-level CGroups its respective :class:`RequestTree` object defines. :param name: Name assigned to the :class:`ResponseTree` object, mirrors the name defined to its respective :class:`RequestTree` Object. - :type name: str :param children: A dictionary that maps children names that this :class:`ResponseTree` object is a parent to and the respective :class:`ResponseTree` object the names represent. - :type children: dict[str,:class:`ResponseTree`] :param low_levels: A dictionary that maps CGroup controller names to the suitable low level CGroup this :class:`ResponseTree` abstracts. - :type low_levels: Dict[str, :class:`_CGroupBase`] :param user_low_levels: A dictionary that maps CGroup controller names to the suitable low level CGroup the :class:`RequestTree` object this class mirrors has specified. This is used within the context of a V1 user defined hierarchy in order to abstract the additional CGroups this class represents when trying to ensure V2 semantic equivalence. 
Done purely for cosmetic reasons. - :type user_low_levels: Dict[str, :class:`_CGroupBase`] :param is_threaded: Boolean flag representing whether or not this ResponseTree object represents a single threaded V2 CGroup or a collection of pseudo-threaded V1 CGroups. - :type is_threaded: bool """ def __init__( @@ -1896,7 +1822,7 @@ def __init__( super().__init__(name=name, is_threaded=is_threaded) @property - def _node_information(self): + def _node_information(self) -> str: # Returns a formatted string, displaying the enabled user-defined controllers and their paths # (alongside the type of CGroup the controller resides in). return ", ".join( @@ -1909,27 +1835,26 @@ def _node_information(self): ) @property - def _children_list(self): + def _children_list(self) -> List['_TreeBase']: # Children Objects are the values in our self.children dict. return list(self.children.values()) @property - def _all_nodes(self): + def _all_nodes(self) -> List[_CGroupBase]: return list( itertools.chain( self.low_levels.values(), itertools.chain.from_iterable( - map(lambda child: child._all_nodes, self.children.values()), + map(lambda child: cast(ResponseTree, child)._all_nodes, self.children.values()), ), ) ) - def add_process(self, pid: int): + def add_process(self, pid: int) -> None: """ Adds the process associated with ``pid`` to the low level CGroups this :class:`ResponseTree` object represents. :param pid: the PID of the process to be added to the low-level CGroups. - :type pid: int :raises TargetStableError: Occurs in the case where this object is a parent to non-threaded children. Ensures V2 hierarchy compatibility. @@ -1945,12 +1870,11 @@ def add_process(self, pid: int): ) ) - def add_thread(self, tid: int): + def add_thread(self, tid: int) -> None: """ Adds the thread associated with the ``tid`` to the low level CGroups this :class:`ResponseTree` object represents. :param tid: the TID of the thread to be added to the low-level CGroups. 
- :type tid: int :raises TargetStableError: Occurs in the case where this object is not threaded. Ensures V2 hierarchy compatibility. diff --git a/devlib/module/cooling.py b/devlib/module/cooling.py index 413360a7c..39c43aeb1 100644 --- a/devlib/module/cooling.py +++ b/devlib/module/cooling.py @@ -1,4 +1,4 @@ -# Copyright 2014-2015 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,50 +16,68 @@ from devlib.module import Module from devlib.utils.serial_port import open_serial_connection +from typing import TYPE_CHECKING, cast +from pexpect import fdpexpect +if TYPE_CHECKING: + from devlib.target import Target class MbedFanActiveCoolingModule(Module): - - name = 'mbed-fan' - timeout = 30 + """ + Module to control active cooling using fan + """ + name: str = 'mbed-fan' + timeout: int = 30 @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: return True - def __init__(self, target, port='/dev/ttyACM0', baud=115200, fan_pin=0): + def __init__(self, target: 'Target', port: str = '/dev/ttyACM0', baud: int = 115200, fan_pin: int = 0): super(MbedFanActiveCoolingModule, self).__init__(target) self.port = port self.baud = baud self.fan_pin = fan_pin - def start(self): + def start(self) -> None: + """ + send motor start to fan + """ with open_serial_connection(timeout=self.timeout, port=self.port, baudrate=self.baud) as target: # pylint: disable=no-member - target.sendline('motor_{}_1'.format(self.fan_pin)) + cast(fdpexpect.fdspawn, target).sendline('motor_{}_1'.format(self.fan_pin)) - def stop(self): + def stop(self) -> None: + """ + send motor stop to fan + """ with open_serial_connection(timeout=self.timeout, port=self.port, baudrate=self.baud) as target: # pylint: disable=no-member - target.sendline('motor_{}_0'.format(self.fan_pin)) + cast(fdpexpect.fdspawn, target).sendline('motor_{}_0'.format(self.fan_pin)) class 
OdroidXU3ctiveCoolingModule(Module): - name = 'odroidxu3-fan' + name: str = 'odroidxu3-fan' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: return target.file_exists('/sys/devices/odroid_fan.15/fan_mode') - def start(self): + def start(self) -> None: + """ + start fan + """ self.target.write_value('/sys/devices/odroid_fan.15/fan_mode', 0, verify=False) self.target.write_value('/sys/devices/odroid_fan.15/pwm_duty', 255, verify=False) def stop(self): + """ + stop fan + """ self.target.write_value('/sys/devices/odroid_fan.15/fan_mode', 0, verify=False) self.target.write_value('/sys/devices/odroid_fan.15/pwm_duty', 1, verify=False) diff --git a/devlib/module/cpufreq.py b/devlib/module/cpufreq.py index 2640a9a8e..9d798f281 100644 --- a/devlib/module/cpufreq.py +++ b/devlib/module/cpufreq.py @@ -1,4 +1,4 @@ -# Copyright 2014-2024 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,22 +17,38 @@ from devlib.exception import TargetStableError from devlib.utils.misc import memoized import devlib.utils.asyn as asyn - +from typing import (TYPE_CHECKING, Dict, List, Tuple, Union, + cast, Optional, Any, Set, Coroutine) +from collections.abc import AsyncGenerator +if TYPE_CHECKING: + from devlib.target import Target # a dict of governor name and a list of it tunables that can't be read -WRITE_ONLY_TUNABLES = { +WRITE_ONLY_TUNABLES: Dict[str, List[str]] = { 'interactive': ['boostpulse'] } class CpufreqModule(Module): - - name = 'cpufreq' + """ + ``cpufreq`` is the kernel subsystem for managing DVFS (Dynamic Voltage and + Frequency Scaling). It allows controlling frequency ranges and switching + policies (governors). The ``devlib`` module exposes the following interface + + .. 
note:: On ARM big.LITTLE systems, all cores on a cluster (usually all cores + of the same type) are in the same frequency domain, so setting + ``cpufreq`` state on one core on a cluster will affect all cores on + that cluster. Because of this, some devices only expose cpufreq sysfs + interface (which is what is used by the ``devlib`` module) on the + first cpu in a cluster. So to keep your scripts portable, always use + the first (online) CPU in a cluster to set ``cpufreq`` state. + """ + name: str = 'cpufreq' @staticmethod @asyn.asyncf - async def probe(target): - paths = [ + async def probe(target: 'Target') -> bool: + paths_tmp: List[Tuple[bool, str]] = [ # x86 with Intel P-State driver (target.abi == 'x86_64', '/sys/devices/system/cpu/intel_pstate'), # Generic CPUFreq support (single policy) @@ -40,8 +56,8 @@ async def probe(target): # Generic CPUFreq support (per CPU policy) (True, '/sys/devices/system/cpu/cpu0/cpufreq'), ] - paths = [ - path[1] for path in paths + paths: List[str] = [ + path[1] for path in paths_tmp if path[0] ] @@ -52,38 +68,48 @@ async def probe(target): return any(exists.values()) - def __init__(self, target): + def __init__(self, target: 'Target'): super(CpufreqModule, self).__init__(target) - self._governor_tunables = {} + self._governor_tunables: Dict[str, Tuple[str, bool, List[str]]] = {} @asyn.asyncf @asyn.memoized_method - async def list_governors(self, cpu): - """Returns a list of governors supported by the cpu.""" + async def list_governors(self, cpu: Union[int, str]) -> List[str]: + """List cpufreq governors available for the specified cpu. Returns a list of + strings. + + :param cpu: The cpu; could be a numeric or the corresponding string (e.g. + ``1`` or ``"cpu1"``). 
+ """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_available_governors'.format(cpu) - output = await self.target.read_value.asyn(sysfile) + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/scaling_available_governors'.format(cpu) + output: str = await self.target.read_value.asyn(sysfile) return output.strip().split() @asyn.asyncf - async def get_governor(self, cpu): - """Returns the governor currently set for the specified CPU.""" + async def get_governor(self, cpu: Union[int, str]) -> str: + """ + Returns the name of the currently set governor for the specified cpu. + + :param cpu: The cpu; could be a numeric or the corresponding string (e.g. + ``1`` or ``"cpu1"``). + """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_governor'.format(cpu) + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/scaling_governor'.format(cpu) return await self.target.read_value.asyn(sysfile) @asyn.asyncf - async def set_governor(self, cpu, governor, **kwargs): + async def set_governor(self, cpu: Union[int, str], governor: str, **kwargs) -> None: """ Set the governor for the specified CPU. See https://www.kernel.org/doc/Documentation/cpu-freq/governors.txt - :param cpu: The CPU for which the governor is to be set. This must be - the full name as it appears in sysfs, e.g. "cpu0". + :param cpu: The CPU for which the governor is to be set. It could be a numeric or the corresponding string (e.g. + ``1`` or ``"cpu1"``). :param governor: The name of the governor to be used. This must be - supported by the specific device. + supported by the specific device (as returned by ``list_governors()``. Additional keyword arguments can be used to specify governor tunables for governors that support them. 
@@ -98,7 +124,7 @@ async def set_governor(self, cpu, governor, **kwargs): """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - supported = await self.list_governors.asyn(cpu) + supported: List[str] = await self.list_governors.asyn(cpu) if governor not in supported: raise TargetStableError('Governor {} not supported for cpu {}'.format(governor, cpu)) sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_governor'.format(cpu) @@ -106,22 +132,20 @@ async def set_governor(self, cpu, governor, **kwargs): await self.set_governor_tunables.asyn(cpu, governor, **kwargs) @asyn.asynccontextmanager - async def use_governor(self, governor, cpus=None, **kwargs): + async def use_governor(self, governor: str, cpus: Optional[List[str]] = None, **kwargs) -> AsyncGenerator: """ Use a given governor, then restore previous governor(s) :param governor: Governor to use on all targeted CPUs (see :meth:`set_governor`) - :type governor: str :param cpus: CPUs affected by the governor change (all by default) - :type cpus: list :Keyword Arguments: Governor tunables, See :meth:`set_governor_tunables` """ if not cpus: cpus = await self.target.list_online_cpus.asyn() - async def get_cpu_info(cpu): + async def get_cpu_info(cpu) -> List[Any]: return await self.target.async_manager.concurrently(( self.get_affected_cpus.asyn(cpu), self.get_governor.asyn(cpu), @@ -131,12 +155,12 @@ async def get_cpu_info(cpu): self.get_frequency.asyn(cpu), )) - cpus_infos = await self.target.async_manager.map_concurrently(get_cpu_info, cpus) + cpus_infos: Dict[int, List[Any]] = await self.target.async_manager.map_concurrently(get_cpu_info, cpus) # Setting a governor & tunables for a cpu will set them for all cpus in # the same cpufreq policy, so only manipulating one cpu per domain is # enough - domains = set( + domains: Set[Any] = set( info[0][0] for info in cpus_infos.values() ) @@ -149,7 +173,7 @@ async def get_cpu_info(cpu): try: yield finally: - async def set_per_cpu_tunables(cpu): + async def 
set_per_cpu_tunables(cpu: int) -> None: domain, prev_gov, tunables, freq = cpus_infos[cpu] # Per-cpu tunables are safe to set concurrently await self.set_governor_tunables.asyn(cpu, prev_gov, per_cpu=True, **tunables) @@ -157,7 +181,7 @@ async def set_per_cpu_tunables(cpu): if prev_gov == "userspace": await self.set_frequency.asyn(cpu, freq) - per_cpu_tunables = self.target.async_manager.concurrently( + per_cpu_tunables: Coroutine = self.target.async_manager.concurrently( set_per_cpu_tunables(cpu) for cpu in domains ) @@ -165,14 +189,14 @@ async def set_per_cpu_tunables(cpu): # Non-per-cpu tunables have to be set one after the other, for each # governor that we had to deal with. - global_tunables = { + global_tunables_dict: Dict[str, Tuple[int, Dict[str, List[str]]]] = { prev_gov: (cpu, tunables) for cpu, (domain, prev_gov, tunables, freq) in cpus_infos.items() } - global_tunables = self.target.async_manager.concurrently( + global_tunables: Coroutine = self.target.async_manager.concurrently( self.set_governor_tunables.asyn(cpu, gov, per_cpu=False, **tunables) - for gov, (cpu, tunables) in global_tunables.items() + for gov, (cpu, tunables) in global_tunables_dict.items() ) global_tunables.__qualname__ = 'CpufreqModule.use_governor..global_tunables' @@ -188,7 +212,11 @@ async def set_per_cpu_tunables(cpu): ) @asyn.asyncf - async def _list_governor_tunables(self, cpu, governor=None): + async def _list_governor_tunables(self, cpu: Union[int, str], + governor: Optional[str] = None) -> Tuple[str, bool, List[str]]: + """ + helper function for list_governor_tunables + """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) @@ -196,6 +224,8 @@ async def _list_governor_tunables(self, cpu, governor=None): governor = await self.get_governor.asyn(cpu) try: + if not governor: + raise TargetStableError return self._governor_tunables[governor] except KeyError: for per_cpu, path in ( @@ -204,7 +234,7 @@ async def _list_governor_tunables(self, cpu, governor=None): (False, 
'/sys/devices/system/cpu/cpufreq/{}'.format(governor)), ): try: - tunables = await self.target.list_directory.asyn(path) + tunables: List[str] = await self.target.list_directory.asyn(path) except TargetStableError: continue else: @@ -213,34 +243,49 @@ async def _list_governor_tunables(self, cpu, governor=None): per_cpu = False tunables = [] - data = (governor, per_cpu, tunables) - self._governor_tunables[governor] = data + data: Tuple[str, bool, List[str]] = (cast(str, governor), per_cpu, tunables) + if governor: + self._governor_tunables[governor] = data return data @asyn.asyncf - async def list_governor_tunables(self, cpu): - """Returns a list of tunables available for the governor on the specified CPU.""" + async def list_governor_tunables(self, cpu: Union[int, str]) -> Tuple[str, bool, List[str]]: + """ + List the tunables for the specified cpu's current governor. + + :param cpu: The cpu; could be a numeric or the corresponding string (e.g. + ``1`` or ``"cpu1"``). + """ _, _, tunables = await self._list_governor_tunables.asyn(cpu) return tunables @asyn.asyncf - async def get_governor_tunables(self, cpu): + async def get_governor_tunables(self, cpu: Union[int, str]) -> Dict[str, List[str]]: + """ + Return a dict with the values of the specified CPU's current governor. + + :param cpu: The cpu; could be a numeric or the corresponding string (e.g. + ``1`` or ``"cpu1"``). 
+ """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) + governor: str + tunable_list: List[str] governor, _, tunable_list = await self._list_governor_tunables.asyn(cpu) - write_only = set(WRITE_ONLY_TUNABLES.get(governor, [])) + write_only: Set[str] = set(WRITE_ONLY_TUNABLES.get(governor, [])) tunable_list = [ tunable for tunable in tunable_list if tunable not in write_only ] - tunables = {} - async def get_tunable(tunable): + tunables: Dict[str, List[str]] = {} + + async def get_tunable(tunable: str) -> str: try: - path = '/sys/devices/system/cpu/{}/cpufreq/{}/{}'.format(cpu, governor, tunable) - x = await self.target.read_value.asyn(path) + path: str = '/sys/devices/system/cpu/{}/cpufreq/{}/{}'.format(cpu, governor, tunable) + x: str = await self.target.read_value.asyn(path) except TargetStableError: # May be an older kernel path = '/sys/devices/system/cpu/cpufreq/{}/{}'.format(governor, tunable) x = await self.target.read_value.asyn(path) @@ -250,7 +295,8 @@ async def get_tunable(tunable): return tunables @asyn.asyncf - async def set_governor_tunables(self, cpu, governor=None, per_cpu=None, **kwargs): + async def set_governor_tunables(self, cpu: Union[int, str], governor: Optional[str] = None, + per_cpu: Optional[bool] = None, **kwargs) -> None: """ Set tunables for the specified governor. Tunables should be specified as keyword arguments. 
Which tunables and values are valid depends on the @@ -276,6 +322,8 @@ async def set_governor_tunables(self, cpu, governor=None, per_cpu=None, **kwargs if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) + gov_per_cpu: bool + valid_tunables: List[str] governor, gov_per_cpu, valid_tunables = await self._list_governor_tunables.asyn(cpu, governor=governor) for tunable, value in kwargs.items(): if tunable in valid_tunables: @@ -283,34 +331,38 @@ async def set_governor_tunables(self, cpu, governor=None, per_cpu=None, **kwargs continue if gov_per_cpu: - path = '/sys/devices/system/cpu/{}/cpufreq/{}/{}'.format(cpu, governor, tunable) + path: str = '/sys/devices/system/cpu/{}/cpufreq/{}/{}'.format(cpu, governor, tunable) else: path = '/sys/devices/system/cpu/cpufreq/{}/{}'.format(governor, tunable) await self.target.write_value.asyn(path, value) else: - message = 'Unexpected tunable {} for governor {} on {}.\n'.format(tunable, governor, cpu) + message: str = 'Unexpected tunable {} for governor {} on {}.\n'.format(tunable, governor, cpu) message += 'Available tunables are: {}'.format(valid_tunables) raise TargetStableError(message) @asyn.asyncf @asyn.memoized_method - async def list_frequencies(self, cpu): - """Returns a sorted list of frequencies supported by the cpu or an empty list - if not could be found.""" + async def list_frequencies(self, cpu: Union[int, str]) -> List[int]: + """ + Returns a sorted list of frequencies supported by the cpu or an empty list + if none could be found. + :param cpu: The cpu; could be a numeric or the corresponding string (e.g. + ``1`` or ``"cpu1"``). 
+ """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) try: - cmd = 'cat /sys/devices/system/cpu/{}/cpufreq/scaling_available_frequencies'.format(cpu) - output = await self.target.execute.asyn(cmd) - available_frequencies = list(map(int, output.strip().split())) # pylint: disable=E1103 + cmd: str = 'cat /sys/devices/system/cpu/{}/cpufreq/scaling_available_frequencies'.format(cpu) + output: str = await self.target.execute.asyn(cmd) + available_frequencies: List[int] = list(map(int, output.strip().split())) # pylint: disable=E1103 except TargetStableError: # On some devices scaling_frequencies is not generated. # http://adrynalyne-teachtofish.blogspot.co.uk/2011/11/how-to-enable-scalingavailablefrequenci.html # Fall back to parsing stats/time_in_state - path = '/sys/devices/system/cpu/{}/cpufreq/stats/time_in_state'.format(cpu) + path: str = '/sys/devices/system/cpu/{}/cpufreq/stats/time_in_state'.format(cpu) try: - out_iter = (await self.target.read_value.asyn(path)).split() + out_iter: List[str] = cast(str, (await self.target.read_value.asyn(path))).split() except TargetStableError: if not self.target.file_exists(path): # Probably intel_pstate. Can't get available freqs. @@ -321,7 +373,7 @@ async def list_frequencies(self, cpu): return sorted(available_frequencies) @memoized - def get_max_available_frequency(self, cpu): + def get_max_available_frequency(self, cpu: Union[str, int]) -> Optional[int]: """ Returns the maximum available frequency for a given core or None if could not be found. @@ -330,16 +382,16 @@ def get_max_available_frequency(self, cpu): return max(freqs) if freqs else None @memoized - def get_min_available_frequency(self, cpu): + def get_min_available_frequency(self, cpu: Union[str, int]) -> Optional[int]: """ Returns the minimum available frequency for a given core or None if could not be found. 
""" - freqs = self.list_frequencies(cpu) + freqs: List[int] = self.list_frequencies(cpu) return min(freqs) if freqs else None @asyn.asyncf - async def get_min_frequency(self, cpu): + async def get_min_frequency(self, cpu: Union[str, int]) -> int: """ Returns the min frequency currently set for the specified CPU. @@ -352,11 +404,11 @@ async def get_min_frequency(self, cpu): """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_min_freq'.format(cpu) + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/scaling_min_freq'.format(cpu) return await self.target.read_int.asyn(sysfile) @asyn.asyncf - async def set_min_frequency(self, cpu, frequency, exact=True): + async def set_min_frequency(self, cpu: Union[str, int], frequency: Union[int, str], exact: bool = True) -> None: """ Set's the minimum value for CPU frequency. Actual frequency will depend on the Governor used and may vary during execution. The value should be @@ -375,20 +427,20 @@ async def set_min_frequency(self, cpu, frequency, exact=True): """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - available_frequencies = await self.list_frequencies.asyn(cpu) + available_frequencies: List[int] = await self.list_frequencies.asyn(cpu) try: value = int(frequency) if exact and available_frequencies and value not in available_frequencies: raise TargetStableError('Can\'t set {} frequency to {}\nmust be in {}'.format(cpu, - value, - available_frequencies)) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_min_freq'.format(cpu) + value, + available_frequencies)) + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/scaling_min_freq'.format(cpu) await self.target.write_value.asyn(sysfile, value) except ValueError: raise ValueError('Frequency must be an integer; got: "{}"'.format(frequency)) @asyn.asyncf - async def get_frequency(self, cpu, cpuinfo=False): + async def get_frequency(self, cpu: Union[str, int], cpuinfo: bool = False) -> int: """ Returns the current 
frequency currently set for the specified CPU. @@ -405,18 +457,20 @@ async def get_frequency(self, cpu, cpuinfo=False): if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/{}'.format( - cpu, - 'cpuinfo_cur_freq' if cpuinfo else 'scaling_cur_freq') + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/{}'.format( + cpu, + 'cpuinfo_cur_freq' if cpuinfo else 'scaling_cur_freq') return await self.target.read_int.asyn(sysfile) @asyn.asyncf - async def set_frequency(self, cpu, frequency, exact=True): + async def set_frequency(self, cpu: Union[str, int], frequency: Union[int, str], exact: bool = True) -> None: """ - Set's the minimum value for CPU frequency. Actual frequency will + Sets the minimum value for CPU frequency. Actual frequency will depend on the Governor used and may vary during execution. The value should be either an int or a string representing an integer. + ``set_frequency`` is only available if the current governor is ``userspace``. + If ``exact`` flag is set (the default), the Value must also be supported by the device. 
The available frequencies can be obtained by calling get_frequencies() or examining @@ -435,16 +489,16 @@ async def set_frequency(self, cpu, frequency, exact=True): try: value = int(frequency) if exact: - available_frequencies = await self.list_frequencies.asyn(cpu) + available_frequencies: List[int] = await self.list_frequencies.asyn(cpu) if available_frequencies and value not in available_frequencies: raise TargetStableError('Can\'t set {} frequency to {}\nmust be in {}'.format(cpu, - value, - available_frequencies)) + value, + available_frequencies)) if await self.get_governor.asyn(cpu) != 'userspace': raise TargetStableError('Can\'t set {} frequency; governor must be "userspace"'.format(cpu)) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_setspeed'.format(cpu) + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/scaling_setspeed'.format(cpu) await self.target.write_value.asyn(sysfile, value, verify=False) - cpuinfo = await self.get_frequency.asyn(cpu, cpuinfo=True) + cpuinfo: int = await self.get_frequency.asyn(cpu, cpuinfo=True) if cpuinfo != value: self.logger.warning( 'The cpufreq value has not been applied properly cpuinfo={} request={}'.format(cpuinfo, value)) @@ -452,7 +506,7 @@ async def set_frequency(self, cpu, frequency, exact=True): raise ValueError('Frequency must be an integer; got: "{}"'.format(frequency)) @asyn.asyncf - async def get_max_frequency(self, cpu): + async def get_max_frequency(self, cpu: Union[str, int]) -> int: """ Returns the max frequency currently set for the specified CPU. 
@@ -464,13 +518,14 @@ async def get_max_frequency(self, cpu): """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_max_freq'.format(cpu) + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/scaling_max_freq'.format(cpu) return await self.target.read_int.asyn(sysfile) @asyn.asyncf - async def set_max_frequency(self, cpu, frequency, exact=True): + async def set_max_frequency(self, cpu: Union[str, int], + frequency: Union[str, int], exact: bool = True) -> None: """ - Set's the minimum value for CPU frequency. Actual frequency will + Sets the maximum value for CPU frequency. Actual frequency will depend on the Governor used and may vary during execution. The value should be either an int or a string representing an integer. The Value must also be supported by the device. The available frequencies can be obtained by calling @@ -492,58 +547,58 @@ async def set_max_frequency(self, cpu, frequency, exact=True): value = int(frequency) if exact and available_frequencies and value not in available_frequencies: raise TargetStableError('Can\'t set {} frequency to {}\nmust be in {}'.format(cpu, - value, - available_frequencies)) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_max_freq'.format(cpu) + value, + available_frequencies)) + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/scaling_max_freq'.format(cpu) await self.target.write_value.asyn(sysfile, value) except ValueError: raise ValueError('Frequency must be an integer; got: "{}"'.format(frequency)) @asyn.asyncf - async def set_governor_for_cpus(self, cpus, governor, **kwargs): + async def set_governor_for_cpus(self, cpus: List[Union[str, int]], governor: str, **kwargs) -> None: """ Set the governor for the specified list of CPUs. See https://www.kernel.org/doc/Documentation/cpu-freq/governors.txt :param cpus: The list of CPU for which the governor is to be set. 
""" - await self.target.async_manager.map_concurrently( + await self.target.async_manager.concurrently( self.set_governor(cpu, governor, **kwargs) for cpu in sorted(set(cpus)) ) @asyn.asyncf - async def set_frequency_for_cpus(self, cpus, freq, exact=False): + async def set_frequency_for_cpus(self, cpus: List[Union[int, str]], freq: int, exact: bool = False) -> None: """ Set the frequency for the specified list of CPUs. See https://www.kernel.org/doc/Documentation/cpu-freq/governors.txt :param cpus: The list of CPU for which the frequency has to be set. """ - await self.target.async_manager.map_concurrently( + await self.target.async_manager.concurrently( self.set_frequency(cpu, freq, exact) for cpu in sorted(set(cpus)) ) @asyn.asyncf - async def set_all_frequencies(self, freq): + async def set_all_frequencies(self, freq: int) -> None: """ Set the specified (minimum) frequency for all the (online) CPUs """ # pylint: disable=protected-access return await self.target._execute_util.asyn( - 'cpufreq_set_all_frequencies {}'.format(freq), - as_root=True) + 'cpufreq_set_all_frequencies {}'.format(freq), + as_root=True) @asyn.asyncf - async def get_all_frequencies(self): + async def get_all_frequencies(self) -> Dict[str, str]: """ Get the current frequency for all the (online) CPUs """ # pylint: disable=protected-access - output = await self.target._execute_util.asyn( - 'cpufreq_get_all_frequencies', as_root=True) - frequencies = {} + output: str = await self.target._execute_util.asyn( + 'cpufreq_get_all_frequencies', as_root=True) + frequencies: Dict[str, str] = {} for x in output.splitlines(): kv = x.split(' ') if kv[0] == '': @@ -552,7 +607,7 @@ async def get_all_frequencies(self): return frequencies @asyn.asyncf - async def set_all_governors(self, governor): + async def set_all_governors(self, governor: str) -> None: """ Set the specified governor for all the (online) CPUs """ @@ -563,10 +618,10 @@ async def set_all_governors(self, governor): as_root=True) except 
TargetStableError as e: if ("echo: I/O error" in str(e) or - "write error: Invalid argument" in str(e)): + "write error: Invalid argument" in str(e)): - cpus_unsupported = [c for c in await self.target.list_online_cpus.asyn() - if governor not in await self.list_governors.asyn(c)] + cpus_unsupported: List[int] = [c for c in await self.target.list_online_cpus.asyn() + if governor not in await self.list_governors.asyn(c)] raise TargetStableError("Governor {} unsupported for CPUs {}".format( governor, cpus_unsupported)) else: @@ -579,7 +634,7 @@ async def get_all_governors(self): """ # pylint: disable=protected-access output = await self.target._execute_util.asyn( - 'cpufreq_get_all_governors', as_root=True) + 'cpufreq_get_all_governors', as_root=True) governors = {} for x in output.splitlines(): kv = x.split(' ') @@ -597,7 +652,7 @@ async def trace_frequencies(self): return await self.target._execute_util.asyn('cpufreq_trace_all_frequencies', as_root=True) @asyn.asyncf - async def get_affected_cpus(self, cpu): + async def get_affected_cpus(self, cpu: Union[str, int]) -> List[int]: """ Get the online CPUs that share a frequency domain with the given CPU """ @@ -611,7 +666,7 @@ async def get_affected_cpus(self, cpu): @asyn.asyncf @asyn.memoized_method - async def get_related_cpus(self, cpu): + async def get_related_cpus(self, cpu: Union[str, int]) -> List[int]: """ Get the CPUs that share a frequency domain with the given CPU """ @@ -620,11 +675,11 @@ async def get_related_cpus(self, cpu): sysfile = '/sys/devices/system/cpu/{}/cpufreq/related_cpus'.format(cpu) - return [int(c) for c in (await self.target.read_value.asyn(sysfile)).split()] + return [int(c) for c in cast(str, (await self.target.read_value.asyn(sysfile))).split()] @asyn.asyncf @asyn.memoized_method - async def get_driver(self, cpu): + async def get_driver(self, cpu: Union[str, int]) -> str: """ Get the name of the driver used by this cpufreq policy. 
""" @@ -633,16 +688,16 @@ async def get_driver(self, cpu): sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_driver'.format(cpu) - return (await self.target.read_value.asyn(sysfile)).strip() + return cast(str, (await self.target.read_value.asyn(sysfile))).strip() @asyn.asyncf - async def iter_domains(self): + async def iter_domains(self) -> AsyncGenerator[Set[int], None]: """ Iterate over the frequency domains in the system """ cpus = set(range(self.target.number_of_cpus)) while cpus: cpu = next(iter(cpus)) # pylint: disable=stop-iteration-return - domain = await self.target.cpufreq.get_related_cpus.asyn(cpu) + domain: Set[int] = await cast(CpufreqModule, self.target.cpufreq).get_related_cpus.asyn(cpu) yield domain cpus = cpus.difference(domain) diff --git a/devlib/module/cpuidle.py b/devlib/module/cpuidle.py index a7d0fef64..53a35a5fc 100644 --- a/devlib/module/cpuidle.py +++ b/devlib/module/cpuidle.py @@ -1,4 +1,4 @@ -# Copyright 2014-2018 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ # limitations under the License. 
# # pylint: disable=attribute-defined-outside-init -from past.builtins import basestring from operator import attrgetter from pprint import pformat @@ -23,24 +22,27 @@ from devlib.utils.types import integer, boolean from devlib.utils.misc import memoized import devlib.utils.asyn as asyn +from typing import Optional, TYPE_CHECKING, Union, List +if TYPE_CHECKING: + from devlib.target import Target class CpuidleState(object): @property - def usage(self): + def usage(self) -> int: return integer(self.get('usage')) @property - def time(self): + def time(self) -> int: return integer(self.get('time')) @property - def is_enabled(self): + def is_enabled(self) -> bool: return not boolean(self.get('disable')) @property - def ordinal(self): + def ordinal(self) -> int: i = len(self.id) while self.id[i - 1].isdigit(): i -= 1 @@ -48,7 +50,8 @@ def ordinal(self): raise ValueError('invalid idle state name: "{}"'.format(self.id)) return int(self.id[i:]) - def __init__(self, target, index, path, name, desc, power, latency, residency): + def __init__(self, target: 'Target', index: int, path: str, name: str, + desc: str, power: int, latency: int, residency: Optional[int]): self.target = target self.index = index self.path = path @@ -57,31 +60,43 @@ def __init__(self, target, index, path, name, desc, power, latency, residency): self.power = power self.latency = latency self.residency = residency - self.id = self.target.path.basename(self.path) - self.cpu = self.target.path.basename(self.target.path.dirname(path)) + self.id: str = self.target.path.basename(self.path) if self.target.path else '' + self.cpu: str = self.target.path.basename(self.target.path.dirname(path)) if self.target.path else '' @asyn.asyncf - async def enable(self): + async def enable(self) -> None: + """ + enable idle state + """ await self.set.asyn('disable', 0) @asyn.asyncf - async def disable(self): + async def disable(self) -> None: + """ + disable idle state + """ await self.set.asyn('disable', 1) @asyn.asyncf - 
async def get(self, prop): - property_path = self.target.path.join(self.path, prop) + async def get(self, prop: str) -> str: + """ + get the property + """ + property_path = self.target.path.join(self.path, prop) if self.target.path else '' return await self.target.read_value.asyn(property_path) @asyn.asyncf - async def set(self, prop, value): - property_path = self.target.path.join(self.path, prop) + async def set(self, prop: str, value: str) -> None: + """ + set the property + """ + property_path = self.target.path.join(self.path, prop) if self.target.path else '' await self.target.write_value.asyn(property_path, value) def __eq__(self, other): if isinstance(other, CpuidleState): return (self.name == other.name) and (self.desc == other.desc) - elif isinstance(other, basestring): + elif isinstance(other, str): return (self.name == other) or (self.desc == other) else: return False @@ -96,19 +111,23 @@ def __str__(self): class Cpuidle(Module): - + """ + ``cpuidle`` is the kernel subsystem for managing CPU low power (idle) states. + """ name = 'cpuidle' root_path = '/sys/devices/system/cpu/cpuidle' @staticmethod @asyn.asyncf - async def probe(target): + async def probe(target: 'Target') -> bool: return await target.file_exists.asyn(Cpuidle.root_path) - def __init__(self, target): + def __init__(self, target: 'Target'): super(Cpuidle, self).__init__(target) - basepath = '/sys/devices/system/cpu/' + basepath: str = '/sys/devices/system/cpu/' + # FIXME - annotating the values_tree based on read_tree_values return type is causing errors due to recursive + # definition of the Node type. 
leaving it out for now values_tree = self.target.read_tree_values(basepath, depth=4, check_exit_code=False) self._states = { @@ -118,7 +137,7 @@ def __init__(self, target): self.target, # state_name is formatted as "state42" index=int(state_name[len('state'):]), - path=self.target.path.join(basepath, cpu_name, 'cpuidle', state_name), + path=self.target.path.join(basepath, cpu_name, 'cpuidle', state_name) if self.target.path else '', name=state_node['name'], desc=state_node['desc'], power=int(state_node['power']), @@ -137,12 +156,18 @@ def __init__(self, target): self.logger.debug('Adding cpuidle states:\n{}'.format(pformat(self._states))) - def get_states(self, cpu=0): + def get_states(self, cpu: Union[int, str] = 0) -> List[CpuidleState]: + """ + get the cpu idle states + """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) return self._states.get(cpu, []) - def get_state(self, state, cpu=0): + def get_state(self, state: Union[str, int], cpu: Union[str, int] = 0) -> CpuidleState: + """ + get the specific cpuidle state values + """ if isinstance(state, int): try: return self.get_states(cpu)[state] @@ -155,29 +180,41 @@ def get_state(self, state, cpu=0): raise ValueError('Cpuidle state {} does not exist'.format(state)) @asyn.asyncf - async def enable(self, state, cpu=0): + async def enable(self, state: Union[str, int], cpu: Union[str, int] = 0) -> None: + """ + enable the specific cpu idle state + """ await self.get_state(state, cpu).enable.asyn() @asyn.asyncf - async def disable(self, state, cpu=0): + async def disable(self, state: Union[str, int], cpu: Union[str, int] = 0) -> None: + """ + disable the specific cpu idle state + """ await self.get_state(state, cpu).disable.asyn() @asyn.asyncf - async def enable_all(self, cpu=0): + async def enable_all(self, cpu: Union[str, int] = 0) -> None: + """ + enable all the cpu idle states + """ await self.target.async_manager.concurrently( state.enable.asyn() for state in self.get_states(cpu) ) @asyn.asyncf - async def 
disable_all(self, cpu=0): + async def disable_all(self, cpu: Union[str, int] = 0) -> None: + """ + disable all cpu idle states + """ await self.target.async_manager.concurrently( state.disable.asyn() for state in self.get_states(cpu) ) @asyn.asyncf - async def perturb_cpus(self): + async def perturb_cpus(self) -> None: """ Momentarily wake each CPU. Ensures cpu_idle events in trace file. """ @@ -185,25 +222,30 @@ async def perturb_cpus(self): await self.target._execute_util.asyn('cpuidle_wake_all_cpus') @asyn.asyncf - async def get_driver(self): - return await self.target.read_value.asyn(self.target.path.join(self.root_path, 'current_driver')) + async def get_driver(self) -> Optional[str]: + """ + get the current driver of idle states + """ + if self.target.path: + return await self.target.read_value.asyn(self.target.path.join(self.root_path, 'current_driver')) + return None @memoized - def list_governors(self): + def list_governors(self) -> List[str]: """Returns a list of supported idle governors.""" - sysfile = self.target.path.join(self.root_path, 'available_governors') - output = self.target.read_value(sysfile) + sysfile: str = self.target.path.join(self.root_path, 'available_governors') if self.target.path else '' + output: str = self.target.read_value(sysfile) return output.strip().split() @asyn.asyncf - async def get_governor(self): + async def get_governor(self) -> str: """Returns the currently selected idle governor.""" - path = self.target.path.join(self.root_path, 'current_governor_ro') + path = self.target.path.join(self.root_path, 'current_governor_ro') if self.target.path else '' if not await self.target.file_exists.asyn(path): - path = self.target.path.join(self.root_path, 'current_governor') + path = self.target.path.join(self.root_path, 'current_governor') if self.target.path else '' return await self.target.read_value.asyn(path) - def set_governor(self, governor): + def set_governor(self, governor: str) -> None: """ Set the idle governor for the 
system. @@ -213,8 +255,8 @@ def set_governor(self, governor): :raises TargetStableError if governor is not supported by the CPU, or if, for some reason, the governor could not be set. """ - supported = self.list_governors() + supported: List[str] = self.list_governors() if governor not in supported: raise TargetStableError('Governor {} not supported'.format(governor)) - sysfile = self.target.path.join(self.root_path, 'current_governor') + sysfile: str = self.target.path.join(self.root_path, 'current_governor') if self.target.path else '' self.target.write_value(sysfile, governor) diff --git a/devlib/module/devfreq.py b/devlib/module/devfreq.py index 00c3154c8..0fd75e7ac 100644 --- a/devlib/module/devfreq.py +++ b/devlib/module/devfreq.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,14 +15,20 @@ from devlib.module import Module from devlib.exception import TargetStableError from devlib.utils.misc import memoized +from typing import TYPE_CHECKING, List, Union, Dict +if TYPE_CHECKING: + from devlib.target import Target -class DevfreqModule(Module): +class DevfreqModule(Module): + """ + The devfreq framework in Linux is used for dynamic voltage and frequency scaling (DVFS) of various devices. 
+ """ name = 'devfreq' @staticmethod - def probe(target): - path = '/sys/class/devfreq/' + def probe(target: 'Target') -> bool: + path: str = '/sys/class/devfreq/' if not target.file_exists(path): return False @@ -33,26 +39,26 @@ def probe(target): return True @memoized - def list_devices(self): + def list_devices(self) -> List[str]: """Returns a list of devfreq devices supported by the target platform.""" sysfile = '/sys/class/devfreq/' return self.target.list_directory(sysfile) @memoized - def list_governors(self, device): + def list_governors(self, device: str) -> List[str]: """Returns a list of governors supported by the device.""" - sysfile = '/sys/class/devfreq/{}/available_governors'.format(device) - output = self.target.read_value(sysfile) + sysfile: str = '/sys/class/devfreq/{}/available_governors'.format(device) + output: str = self.target.read_value(sysfile) return output.strip().split() - def get_governor(self, device): + def get_governor(self, device: Union[str, int]) -> str: """Returns the governor currently set for the specified device.""" if isinstance(device, int): device = 'device{}'.format(device) sysfile = '/sys/class/devfreq/{}/governor'.format(device) return self.target.read_value(sysfile) - def set_governor(self, device, governor): + def set_governor(self, device: str, governor: str) -> None: """ Set the governor for the specified device. @@ -68,25 +74,25 @@ def set_governor(self, device, governor): for some reason, the governor could not be set. 
""" - supported = self.list_governors(device) + supported: List[str] = self.list_governors(device) if governor not in supported: raise TargetStableError('Governor {} not supported for device {}'.format(governor, device)) - sysfile = '/sys/class/devfreq/{}/governor'.format(device) + sysfile: str = '/sys/class/devfreq/{}/governor'.format(device) self.target.write_value(sysfile, governor) @memoized - def list_frequencies(self, device): + def list_frequencies(self, device: str) -> List[int]: """ Returns a list of frequencies supported by the device or an empty list if could not be found. """ - cmd = 'cat /sys/class/devfreq/{}/available_frequencies'.format(device) - output = self.target.execute(cmd) - available_frequencies = [int(freq) for freq in output.strip().split()] + cmd: str = 'cat /sys/class/devfreq/{}/available_frequencies'.format(device) + output: str = self.target.execute(cmd) + available_frequencies: List[int] = [int(freq) for freq in output.strip().split()] return available_frequencies - def get_min_frequency(self, device): + def get_min_frequency(self, device: str) -> int: """ Returns the min frequency currently set for the specified device. @@ -100,7 +106,7 @@ def get_min_frequency(self, device): sysfile = '/sys/class/devfreq/{}/min_freq'.format(device) return self.target.read_int(sysfile) - def set_min_frequency(self, device, frequency, exact=True): + def set_min_frequency(self, device: str, frequency: Union[int, str], exact: bool = True) -> None: """ Sets the minimum value for device frequency. Actual frequency will depend on the thermal governor used and may vary during execution. The @@ -117,19 +123,19 @@ def set_min_frequency(self, device, frequency, exact=True): :raises: ValueError if ``frequency`` is not an integer. 
""" - available_frequencies = self.list_frequencies(device) + available_frequencies: List[int] = self.list_frequencies(device) try: value = int(frequency) if exact and available_frequencies and value not in available_frequencies: raise TargetStableError('Can\'t set {} frequency to {}\nmust be in {}'.format(device, - value, - available_frequencies)) - sysfile = '/sys/class/devfreq/{}/min_freq'.format(device) + value, + available_frequencies)) + sysfile: str = '/sys/class/devfreq/{}/min_freq'.format(device) self.target.write_value(sysfile, value) except ValueError: raise ValueError('Frequency must be an integer; got: "{}"'.format(frequency)) - def get_frequency(self, device): + def get_frequency(self, device: str) -> int: """ Returns the current frequency currently set for the specified device. @@ -140,10 +146,10 @@ def get_frequency(self, device): :raises: TargetStableError if for some reason the frequency could not be read. """ - sysfile = '/sys/class/devfreq/{}/cur_freq'.format(device) + sysfile: str = '/sys/class/devfreq/{}/cur_freq'.format(device) return self.target.read_int(sysfile) - def get_max_frequency(self, device): + def get_max_frequency(self, device: str) -> int: """ Returns the max frequency currently set for the specified device. @@ -153,10 +159,10 @@ def get_max_frequency(self, device): :raises: TargetStableError if for some reason the frequency could not be read. """ - sysfile = '/sys/class/devfreq/{}/max_freq'.format(device) + sysfile: str = '/sys/class/devfreq/{}/max_freq'.format(device) return self.target.read_int(sysfile) - def set_max_frequency(self, device, frequency, exact=True): + def set_max_frequency(self, device: str, frequency: Union[int, str], exact: bool = True) -> None: """ Sets the maximum value for device frequency. Actual frequency will depend on the Governor used and may vary during execution. 
The value @@ -173,7 +179,7 @@ def set_max_frequency(self, device, frequency, exact=True): :raises: ValueError if ``frequency`` is not an integer. """ - available_frequencies = self.list_frequencies(device) + available_frequencies: List[int] = self.list_frequencies(device) try: value = int(frequency) except ValueError: @@ -181,12 +187,12 @@ def set_max_frequency(self, device, frequency, exact=True): if exact and value not in available_frequencies: raise TargetStableError('Can\'t set {} frequency to {}\nmust be in {}'.format(device, - value, - available_frequencies)) - sysfile = '/sys/class/devfreq/{}/max_freq'.format(device) + value, + available_frequencies)) + sysfile: str = '/sys/class/devfreq/{}/max_freq'.format(device) self.target.write_value(sysfile, value) - def set_governor_for_devices(self, devices, governor): + def set_governor_for_devices(self, devices: List[str], governor: str) -> None: """ Set the governor for the specified list of devices. @@ -195,7 +201,7 @@ def set_governor_for_devices(self, devices, governor): for device in devices: self.set_governor(device, governor) - def set_all_governors(self, governor): + def set_all_governors(self, governor: str) -> None: """ Set the specified governor for all the (available) devices """ @@ -204,22 +210,22 @@ def set_all_governors(self, governor): 'devfreq_set_all_governors {}'.format(governor), as_root=True) except TargetStableError as e: if ("echo: I/O error" in str(e) or - "write error: Invalid argument" in str(e)): + "write error: Invalid argument" in str(e)): - devs_unsupported = [d for d in self.target.list_devices() - if governor not in self.list_governors(d)] + devs_unsupported: List[str] = [d for d in self.list_devices() + if governor not in self.list_governors(d)] raise TargetStableError("Governor {} unsupported for devices {}".format( governor, devs_unsupported)) else: raise - def get_all_governors(self): + def get_all_governors(self) -> Dict[str, str]: """ Get the current governor for all the 
(online) CPUs """ - output = self.target._execute_util( # pylint: disable=protected-access - 'devfreq_get_all_governors', as_root=True) - governors = {} + output: str = self.target._execute_util( # pylint: disable=protected-access + 'devfreq_get_all_governors', as_root=True) + governors: Dict[str, str] = {} for x in output.splitlines(): kv = x.split(' ') if kv[0] == '': @@ -227,7 +233,7 @@ def get_all_governors(self): governors[kv[0]] = kv[1] return governors - def set_frequency_for_devices(self, devices, freq, exact=False): + def set_frequency_for_devices(self, devices: List[str], freq: Union[int, str], exact: bool = False) -> None: """ Set the frequency for the specified list of devices. @@ -237,21 +243,21 @@ def set_frequency_for_devices(self, devices, freq, exact=False): self.set_max_frequency(device, freq, exact) self.set_min_frequency(device, freq, exact) - def set_all_frequencies(self, freq): + def set_all_frequencies(self, freq: Union[int, str]) -> None: """ Set the specified (minimum) frequency for all the (available) devices """ return self.target._execute_util( # pylint: disable=protected-access - 'devfreq_set_all_frequencies {}'.format(freq), - as_root=True) + 'devfreq_set_all_frequencies {}'.format(freq), + as_root=True) - def get_all_frequencies(self): + def get_all_frequencies(self) -> Dict[str, str]: """ Get the current frequency for all the (available) devices """ - output = self.target._execute_util( # pylint: disable=protected-access - 'devfreq_get_all_frequencies', as_root=True) - frequencies = {} + output: str = self.target._execute_util( # pylint: disable=protected-access + 'devfreq_get_all_frequencies', as_root=True) + frequencies: Dict[str, str] = {} for x in output.splitlines(): kv = x.split(' ') if kv[0] == '': diff --git a/devlib/module/gpufreq.py b/devlib/module/gpufreq.py index 9f0a9529d..b8be14f4d 100644 --- a/devlib/module/gpufreq.py +++ b/devlib/module/gpufreq.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 
ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,41 +31,50 @@ from devlib.module import Module from devlib.exception import TargetStableError from devlib.utils.misc import memoized +from typing import TYPE_CHECKING, List +if TYPE_CHECKING: + from devlib.target import Target -class GpufreqModule(Module): +class GpufreqModule(Module): + """ + module that handles gpu frequency scaling + """ name = 'gpufreq' path = '' - def __init__(self, target): + def __init__(self, target: 'Target'): super(GpufreqModule, self).__init__(target) - frequencies_str = self.target.read_value("/sys/kernel/gpu/gpu_freq_table") - self.frequencies = list(map(int, frequencies_str.split(" "))) + frequencies_str: str = self.target.read_value("/sys/kernel/gpu/gpu_freq_table") + self.frequencies: List[int] = list(map(int, frequencies_str.split(" "))) self.frequencies.sort() - self.governors = self.target.read_value("/sys/kernel/gpu/gpu_available_governor").split(" ") + self.governors: List[str] = self.target.read_value("/sys/kernel/gpu/gpu_available_governor").split(" ") @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: # kgsl/Adreno - probe_path = '/sys/kernel/gpu/' + probe_path: str = '/sys/kernel/gpu/' if target.file_exists(probe_path): - model = target.read_value(probe_path + "gpu_model") + model: str = target.read_value(probe_path + "gpu_model") if re.search('adreno', model, re.IGNORECASE): return True return False - def set_governor(self, governor): + def set_governor(self, governor: str) -> None: + """ + set the governor to the gpu + """ if governor not in self.governors: raise TargetStableError('Governor {} not supported for gpu'.format(governor)) self.target.write_value("/sys/kernel/gpu/gpu_governor", governor) - def get_frequencies(self): + def get_frequencies(self) -> List[int]: """ Returns the list of frequencies that the GPU can have """ return 
self.frequencies - def get_current_frequency(self): + def get_current_frequency(self) -> int: """ Returns the current frequency currently set for the GPU. @@ -79,7 +88,7 @@ def get_current_frequency(self): return int(self.target.read_value("/sys/kernel/gpu/gpu_clock")) @memoized - def get_model_name(self): + def get_model_name(self) -> str: """ Returns the model name reported by the GPU. """ diff --git a/devlib/module/hotplug.py b/devlib/module/hotplug.py index 7d5ea5f64..731770bf6 100644 --- a/devlib/module/hotplug.py +++ b/devlib/module/hotplug.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,32 +15,48 @@ from devlib.module import Module from devlib.exception import TargetTransientError +from typing import TYPE_CHECKING, Dict, cast, Union, List +if TYPE_CHECKING: + from devlib.target import Target class HotplugModule(Module): - + """ + Kernel ``hotplug`` subsystem allows offlining ("removing") cores from the + system, and onlining them back in. The ``devlib`` module exposes a simple + interface to this subsystem + """ name = 'hotplug' base_path = '/sys/devices/system/cpu' @classmethod - def probe(cls, target): # pylint: disable=arguments-differ + def probe(cls, target: 'Target') -> bool: # pylint: disable=arguments-differ # If a system has just 1 CPU, it makes not sense to hotplug it. # If a system has more than 1 CPU, CPU0 could be configured to be not # hotpluggable. 
Thus, check for hotplug support by looking at CPU1 path = cls._cpu_path(target, 1) - return target.file_exists(path) and target.is_rooted + return cast(bool, target.file_exists(path) and target.is_rooted) @classmethod - def _cpu_path(cls, target, cpu): + def _cpu_path(cls, target: 'Target', cpu: Union[int, str]) -> str: + """ + get path to cpu online + """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - return target.path.join(cls.base_path, cpu, 'online') + return target.path.join(cls.base_path, cpu, 'online') if target.path else '' - def list_hotpluggable_cpus(self): + def list_hotpluggable_cpus(self) -> List[int]: + """ + get the list of hotpluggable cpus + """ return [cpu for cpu in range(self.target.number_of_cpus) if self.target.file_exists(self._cpu_path(self.target, cpu))] - def online_all(self, verify=True): + def online_all(self, verify: bool = True) -> None: + """ + bring all cpus online + """ self.target._execute_util('hotplug_online_all', # pylint: disable=protected-access as_root=self.target.is_rooted) if verify: @@ -48,37 +64,60 @@ def online_all(self, verify=True): if offline: raise TargetTransientError('The following CPUs failed to come back online: {}'.format(offline)) - def online(self, *args): + def online(self, *args) -> None: + """ + bring online specific cpus + """ for cpu in args: self.hotplug(cpu, online=True) - def offline(self, *args): + def offline(self, *args) -> None: + """ + take specific cpus offline + """ for cpu in args: self.hotplug(cpu, online=False) - def hotplug(self, cpu, online): + def hotplug(self, cpu: Union[int, str], online: bool) -> None: + """ + bring cpus online or offline + """ path = self._cpu_path(self.target, cpu) if not self.target.file_exists(path): return value = 1 if online else 0 self.target.write_value(path, value) - def _get_path(self, path): + def _get_path(self, path: str) -> str: + """ + get path to cpu directory + """ return self.target.path.join(self.base_path, - path) + path) if self.target.path 
else '' - def fail(self, cpu, state): + def fail(self, cpu: Union[str, int], state: str) -> None: + """ + set fail status for cpu hotplug + """ path = self._get_path('cpu{}/hotplug/fail'.format(cpu)) return self.target.write_value(path, state) - def get_state(self, cpu): + def get_state(self, cpu: Union[int, str]) -> str: + """ + get the hotplug state of the cpu + """ path = self._get_path('cpu{}/hotplug/state'.format(cpu)) return self.target.read_value(path) - def get_states(self): - path = self._get_path('hotplug/states') - states_string = self.target.read_value(path) - return dict( - map(str.strip, string.split(':', 1)) - for string in states_string.strip().splitlines() - ) + def get_states(self) -> Dict[str, str]: + """ + get the possible values for hotplug states + """ + path: str = self._get_path('hotplug/states') + states_string: str = self.target.read_value(path) + return { + key.strip(): value.strip() + for line in states_string.strip().splitlines() + if ':' in line + for key, value in [line.split(':', 1)] + } diff --git a/devlib/module/hwmon.py b/devlib/module/hwmon.py index 3ecc55ca9..ccd32f2a7 100644 --- a/devlib/module/hwmon.py +++ b/devlib/module/hwmon.py @@ -1,4 +1,4 @@ -# Copyright 2015-2018 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,6 +18,10 @@ from devlib import TargetStableError from devlib.module import Module from devlib.utils.types import integer +from typing import (TYPE_CHECKING, Set, Union, cast, DefaultDict, + Dict, List, Match, Optional) +if TYPE_CHECKING: + from devlib.target import Target HWMON_ROOT = '/sys/class/hwmon' @@ -25,39 +29,54 @@ class HwmonSensor(object): - - def __init__(self, device, path, kind, number): + """ + hardware monitoring sensor + """ + def __init__(self, device: 'HwmonDevice', path: str, + kind: str, number: int): self.device = device self.path = path self.kind = kind self.number = number - self.target = self.device.target - self.name = '{}/{}{}'.format(self.device.name, self.kind, self.number) + self.target: 'Target' = self.device.target + self.name: str = '{}/{}{}'.format(self.device.name, self.kind, self.number) self.label = self.name - self.items = set() + self.items: Set[str] = set() - def add(self, item): + def add(self, item: str) -> None: + """ + add item to items set + """ self.items.add(item) if item == 'label': - self.label = self.get('label') + self.label = cast(str, self.get('label')) - def get(self, item): + def get(self, item: str) -> Union[int, str]: + """ + get the value of the item + """ path = self.get_file(item) value = self.target.read_value(path) try: - return integer(value) + return integer(value) except (TypeError, ValueError): return value - def set(self, item, value): - path = self.get_file(item) + def set(self, item: str, value: Union[int, str]) -> None: + """ + set value to the item + """ + path: str = self.get_file(item) self.target.write_value(path, value) - def get_file(self, item): + def get_file(self, item: str) -> str: + """ + get file path + """ if item not in self.items: raise ValueError('item "{}" does not exist for {}'.format(item, self.name)) filename = '{}{}_{}'.format(self.kind, self.number, item) - return self.target.path.join(self.path, filename) + return self.target.path.join(self.path, filename) if 
self.target.path else '' def __str__(self): if self.name != self.label: @@ -70,34 +89,43 @@ def __str__(self): class HwmonDevice(object): - + """ + Hardware monitor device + """ @property - def sensors(self): - all_sensors = [] + def sensors(self) -> List[HwmonSensor]: + """ + get all the hardware monitoring sensors + """ + all_sensors: List[HwmonSensor] = [] for sensors_of_kind in self._sensors.values(): all_sensors.extend(list(sensors_of_kind.values())) return all_sensors - def __init__(self, target, path, name, fields): + def __init__(self, target: 'Target', path: str, name: str, fields: List[str]): self.target = target self.path = path self.name = name - self._sensors = defaultdict(dict) + self._sensors: DefaultDict[str, Dict[int, HwmonSensor]] = defaultdict(dict) path = self.path - if not path.endswith(self.target.path.sep): - path += self.target.path.sep - for entry in fields: - match = HWMON_FILE_REGEX.search(entry) - if match: - kind = match.group('kind') - number = int(match.group('number')) - item = match.group('item') - if number not in self._sensors[kind]: - sensor = HwmonSensor(self, self.path, kind, number) - self._sensors[kind][number] = sensor - self._sensors[kind][number].add(item) - - def get(self, kind, number=None): + if self.target.path: + if not path.endswith(self.target.path.sep): + path += self.target.path.sep + for entry in fields: + match: Optional[Match[str]] = HWMON_FILE_REGEX.search(entry) + if match: + kind: str = match.group('kind') + number: int = int(match.group('number')) + item: str = match.group('item') + if number not in self._sensors[kind]: + sensor = HwmonSensor(self, self.path, kind, number) + self._sensors[kind][number] = sensor + self._sensors[kind][number].add(item) + + def get(self, kind: str, number: Optional[int] = None) -> Union[List[HwmonSensor], HwmonSensor, None]: + """ + get the hardware monitor sensors of the specified kind + """ if number is None: return [s for _, s in sorted(self._sensors[kind].items(), 
key=lambda x: x[0])] @@ -111,11 +139,15 @@ def __str__(self): class HwmonModule(Module): - + """ + The hwmon (hardware monitoring) subsystem in Linux is used to monitor various hardware parameters + such as temperature, voltage, and fan speed. This subsystem provides a standardized interface for + accessing sensor data from different hardware components. + """ name = 'hwmon' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: try: target.list_directory(HWMON_ROOT, as_root=target.is_rooted) except TargetStableError: @@ -124,23 +156,29 @@ def probe(target): return True @property - def sensors(self): - all_sensors = [] + def sensors(self) -> List[HwmonSensor]: + """ + hardware monitoring sensors in all hardware monitoring devices + """ + all_sensors: List[HwmonSensor] = [] for device in self.devices: all_sensors.extend(device.sensors) return all_sensors - def __init__(self, target): + def __init__(self, target: 'Target'): super(HwmonModule, self).__init__(target) - self.root = HWMON_ROOT - self.devices = [] + self.root: str = HWMON_ROOT + self.devices: List[HwmonDevice] = [] self.scan() - def scan(self): + def scan(self) -> None: + """ + scan and add devices to the hardware monitor module + """ values_tree = self.target.read_tree_values(self.root, depth=3, tar=True) for entry_id, fields in values_tree.items(): - path = self.target.path.join(self.root, entry_id) - name = fields.pop('name', None) + path: str = self.target.path.join(self.root, entry_id) if self.target.path else '' + name: Optional[str] = fields.pop('name', None) if name is None: continue self.logger.debug('Adding device {}'.format(name)) diff --git a/devlib/module/sched.py b/devlib/module/sched.py index e1d526dfd..fc7902acb 100644 --- a/devlib/module/sched.py +++ b/devlib/module/sched.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance
with the License. @@ -16,12 +16,16 @@ import logging import re -from past.builtins import basestring - from devlib.module import Module -from devlib.utils.misc import memoized +from devlib.utils.misc import memoized, get_logger from devlib.utils.types import boolean from devlib.exception import TargetStableError +from typing import (TYPE_CHECKING, cast, Match, Dict, + Any, List, Pattern, Union, Optional, + Tuple, Set) +if TYPE_CHECKING: + from devlib.target import Target + class SchedProcFSNode(object): """ @@ -29,7 +33,6 @@ class SchedProcFSNode(object): :param nodes: Dictionnary view of the underlying procfs nodes (as returned by devlib.read_tree_values()) - :type nodes: dict Say you want to represent this path/data: @@ -49,30 +52,33 @@ class SchedProcFSNode(object): MC """ - _re_procfs_node = re.compile(r"(?P.*\D)(?P\d+)$") + _re_procfs_node: Pattern[str] = re.compile(r"(?P.*\D)(?P\d+)$") - PACKABLE_ENTRIES = [ + PACKABLE_ENTRIES: List[str] = [ "cpu", "domain", "group" ] @staticmethod - def _ends_with_digits(node): - if not isinstance(node, basestring): + def _ends_with_digits(node: str) -> bool: + """ + returns True if the node ends with digits + """ + if not isinstance(node, str): return False - return re.search(SchedProcFSNode._re_procfs_node, node) != None + return re.search(SchedProcFSNode._re_procfs_node, node) is not None @staticmethod - def _node_digits(node): + def _node_digits(node: str) -> int: """ :returns: The ending digits of the procfs node """ - return int(re.search(SchedProcFSNode._re_procfs_node, node).group("digits")) + return int(cast(Match, re.search(SchedProcFSNode._re_procfs_node, node)).group("digits")) @staticmethod - def _node_name(node): + def _node_name(node: str) -> str: """ :returns: The name of the procfs node """ @@ -83,7 +89,7 @@ def _node_name(node): return node @classmethod - def _packable(cls, node): + def _packable(cls, node: str) -> bool: """ :returns: Whether it makes sense to pack a node into a common entry """ @@ -91,14 
+97,18 @@ def _packable(cls, node): SchedProcFSNode._node_name(node) in cls.PACKABLE_ENTRIES) @staticmethod - def _build_directory(node_name, node_data): + def _build_directory(node_name: str, + node_data: Any) -> Union['SchedDomain', 'SchedProcFSNode']: + """ + create a new sched domain or a new procfs node + """ if node_name.startswith("domain"): return SchedDomain(node_data) else: return SchedProcFSNode(node_data) @staticmethod - def _build_entry(node_data): + def _build_entry(node_data: Any) -> Union[int, Any]: value = node_data # Most nodes just contain numerical data, try to convert @@ -110,32 +120,33 @@ def _build_entry(node_data): return value @staticmethod - def _build_node(node_name, node_data): + def _build_node(node_name: str, node_data: Any) -> Union['SchedDomain', 'SchedProcFSNode', + int, Any]: if isinstance(node_data, dict): return SchedProcFSNode._build_directory(node_name, node_data) else: return SchedProcFSNode._build_entry(node_data) - def __getattr__(self, name): + def __getattr__(self, name: str): return self._dyn_attrs[name] - def __init__(self, nodes): + def __init__(self, nodes: Dict[str, 'SchedProcFSNode']): self.procfs = nodes # First, reduce the procs fields by packing them if possible # Find which entries can be packed into a common entry - packables = { - node : SchedProcFSNode._node_name(node) + "s" - for node in list(nodes.keys()) if SchedProcFSNode._packable(node) + packables: Dict[str, str] = { + node: SchedProcFSNode._node_name(node) + "s" + for node in list(cast(SchedProcFSNode, nodes).keys()) if SchedProcFSNode._packable(node) } - self._dyn_attrs = {} + self._dyn_attrs: Dict[str, Any] = {} for dest in set(packables.values()): self._dyn_attrs[dest] = {} # Pack common entries for key, dest in packables.items(): - i = SchedProcFSNode._node_digits(key) + i: int = SchedProcFSNode._node_digits(key) self._dyn_attrs[dest][i] = self._build_node(key, nodes[key]) # Build the other nodes @@ -153,13 +164,15 @@ class _SchedDomainFlag: 
exposed. """ - _INSTANCES = {} + _INSTANCES: Dict['_SchedDomainFlag', '_SchedDomainFlag'] = {} """ Dictionary storing the instances so that they can be compared with ``is`` operator. """ + name: str + _value: Optional[int] - def __new__(cls, name, value, doc=None): + def __new__(cls, name: str, value: Optional[int], doc: Optional[str] = None): self = super().__new__(cls) self.name = name self._value = value @@ -175,7 +188,7 @@ def __hash__(self): return hash((self.name, self._value)) @property - def value(self): + def value(self) -> Optional[int]: value = self._value if value is None: raise AttributeError('The kernel does not expose the sched domain flag values') @@ -183,14 +196,14 @@ def value(self): return value @staticmethod - def check_version(target, logger): + def check_version(target: 'Target', logger: logging.Logger) -> None: """ Check the target and see if its kernel version matches our view of the world """ - parts = target.kernel_version.parts + parts: Tuple[Optional[int], Optional[int], Optional[int]] = target.kernel_version.parts # Checked to be valid from v4.4 # Not saved as a class attribute else it'll be converted to an enum - ref_parts = (4, 4, 0) + ref_parts: Tuple[int, int, int] = (4, 4, 0) if parts < ref_parts: logger.warn( "Sched domain flags are defined for kernels v{} and up, " @@ -212,7 +225,7 @@ class _SchedDomainFlagMeta(type): backward compatibility. 
""" @property - def _flags(self): + def _flags(self) -> List[Any]: return [ attr for name, attr in self.__dict__.items() @@ -280,10 +293,10 @@ class SchedDomain(SchedProcFSNode): """ Represents a sched domain as seen through procfs """ - def __init__(self, nodes): + def __init__(self, nodes: Dict[str, SchedProcFSNode]): super().__init__(nodes) - flags = self.flags + flags: Union[Set[_SchedDomainFlag], str] = self.flags # Recent kernels now have a space-separated list of flags instead of a # packed bitfield if isinstance(flags, str): @@ -292,8 +305,8 @@ def __init__(self, nodes): for name in flags.split() } else: - def has_flag(flags, flag): - return flags & flag.value == flag.value + def has_flag(flags: Set[_SchedDomainFlag], flag: _SchedDomainFlag): + return any(f.value == flag.value for f in flags) flags = { flag @@ -303,69 +316,79 @@ def has_flag(flags, flag): self.flags = flags -def _select_path(target, paths, name): + +def _select_path(target: 'Target', paths: List[str], name: str) -> str: + """ + select existing file path + """ for p in paths: if target.file_exists(p): return p raise TargetStableError('No {} found. 
Tried: {}'.format(name, ', '.join(paths))) + class SchedProcFSData(SchedProcFSNode): """ Root class for creating & storing SchedProcFSNode instances """ - _read_depth = 6 + _read_depth: int = 6 @classmethod - def get_data_root(cls, target): + def get_data_root(cls, target: 'Target'): # Location differs depending on kernel version paths = ['/sys/kernel/debug/sched/domains/', '/proc/sys/kernel/sched_domain'] return _select_path(target, paths, "sched_domain debug directory") @staticmethod - def available(target): + def available(target: 'Target') -> bool: + """ + check availability of sched domains + """ try: path = SchedProcFSData.get_data_root(target) except TargetStableError: return False - cpus = target.list_directory(path, as_root=target.is_rooted) + cpus: List[str] = target.list_directory(path, as_root=target.is_rooted) if not cpus: return False # Even if we have a CPU entry, it can be empty (e.g. hotplugged out) # Make sure some data is there for cpu in cpus: - if target.file_exists(target.path.join(path, cpu, "domain0", "flags")): + if target.file_exists(target.path.join(path, cpu, "domain0", "flags") if target.path else ''): return True return False - def __init__(self, target, path=None): + def __init__(self, target: 'Target', path: Optional[str] = None): if path is None: path = SchedProcFSData.get_data_root(target) - procfs = target.read_tree_values(path, depth=self._read_depth) + procfs: Dict[str, 'SchedProcFSNode'] = target.read_tree_values(path, depth=self._read_depth) super(SchedProcFSData, self).__init__(procfs) class SchedModule(Module): + """ + scheduler module + """ + name: str = 'sched' - name = 'sched' - - cpu_sysfs_root = '/sys/devices/system/cpu' + cpu_sysfs_root: str = '/sys/devices/system/cpu' @staticmethod - def probe(target): - logger = logging.getLogger(SchedModule.name) + def probe(target: 'Target') -> bool: + logger: logging.Logger = get_logger(SchedModule.name) SchedDomainFlag.check_version(target, logger) # It makes sense to load this 
module if at least one of those # functionalities is enabled - schedproc = SchedProcFSData.available(target) - debug = SchedModule.target_has_debug(target) - dmips = any([target.file_exists(SchedModule.cpu_dmips_capacity_path(target, cpu)) - for cpu in target.list_online_cpus()]) + schedproc: bool = SchedProcFSData.available(target) + debug: bool = SchedModule.target_has_debug(target) + dmips: bool = any([target.file_exists(SchedModule.cpu_dmips_capacity_path(target, cpu)) + for cpu in target.list_online_cpus()]) logger.info("Scheduler sched_domain procfs entries %s", "found" if schedproc else "not found") @@ -376,16 +399,17 @@ def probe(target): return schedproc or debug or dmips - def __init__(self, target): + def __init__(self, target: 'Target'): super().__init__(target) @classmethod - def get_sched_features_path(cls, target): + def get_sched_features_path(cls, target: 'Target') -> str: # Location differs depending on kernel version - paths = ['/sys/kernel/debug/sched/features', '/sys/kernel/debug/sched_features'] + paths: List[str] = ['/sys/kernel/debug/sched/features', '/sys/kernel/debug/sched_features'] return _select_path(target, paths, "sched_features file") - def get_kernel_attributes(self, matching=None, check_exit_code=True): + def get_kernel_attributes(self, matching: Optional[str] = None, + check_exit_code: bool = True) -> Dict[str, Union[int, bool]]: """ Get the value of scheduler attributes. 
@@ -406,21 +430,22 @@ def get_kernel_attributes(self, matching=None, check_exit_code=True): command = 'sched_get_kernel_attributes {}'.format( matching if matching else '' ) - output = self.target._execute_util(command, as_root=self.target.is_rooted, - check_exit_code=check_exit_code) - result = {} + output: str = self.target._execute_util(command, as_root=self.target.is_rooted, + check_exit_code=check_exit_code) + result: Dict[str, Union[int, bool]] = {} for entry in output.strip().split('\n'): if ':' not in entry: continue - path, value = entry.strip().split(':', 1) - if value in ['0', '1']: - value = bool(int(value)) - elif value.isdigit(): - value = int(value) + path, value_s = entry.strip().split(':', 1) + if value_s in ['0', '1']: + value: Union[int, bool] = bool(int(value_s)) + elif value_s.isdigit(): + value = int(value_s) result[path] = value return result - def set_kernel_attribute(self, attr, value, verify=True): + def set_kernel_attribute(self, attr: str, value: Union[bool, int, str], + verify: bool = True) -> None: """ Set the value of a scheduler attribute. 
@@ -434,11 +459,14 @@ def set_kernel_attribute(self, attr, value, verify=True): value = '1' if value else '0' elif isinstance(value, int): value = str(value) - path = '/proc/sys/kernel/sched_' + attr + path: str = '/proc/sys/kernel/sched_' + attr self.target.write_value(path, value, verify) @classmethod - def target_has_debug(cls, target): + def target_has_debug(cls, target: 'Target') -> bool: + """ + True if target has SCHED_DEBUG config set and has sched features + """ if target.config.get('SCHED_DEBUG') != 'y': return False @@ -448,23 +476,23 @@ def target_has_debug(cls, target): except TargetStableError: return False - def get_features(self): + def get_features(self) -> Dict[str, bool]: """ Get the status of each sched feature :returns: a dictionary of features and their "is enabled" status """ - feats = self.target.read_value(self.get_sched_features_path(self.target)) - features = {} + feats: str = self.target.read_value(self.get_sched_features_path(self.target)) + features: Dict[str, bool] = {} for feat in feats.split(): - value = True + value: bool = True if feat.startswith('NO'): feat = feat.replace('NO_', '', 1) value = False features[feat] = value return features - def set_feature(self, feature, enable, verify=True): + def set_feature(self, feature: str, enable: bool, verify: bool = True): """ Set the status of a specified scheduler feature @@ -475,63 +503,63 @@ def set_feature(self, feature, enable, verify=True): :raise RuntimeError: if the specified feature cannot be set """ feature = feature.upper() - feat_value = feature + feat_value: str = feature if not boolean(enable): feat_value = 'NO_' + feat_value self.target.write_value(self.get_sched_features_path(self.target), feat_value, verify=False) if not verify: return - msg = 'Failed to set {}, feature not supported?'.format(feat_value) - features = self.get_features() - feat_value = features.get(feature, not enable) - if feat_value != enable: + msg: str = 'Failed to set {}, feature not 
supported?'.format(feat_value) + features: Dict[str, bool] = self.get_features() + feat_ = features.get(feature, not enable) + if feat_ != enable: raise RuntimeError(msg) - def get_cpu_sd_info(self, cpu): + def get_cpu_sd_info(self, cpu: int) -> SchedProcFSData: """ :returns: An object view of the sched_domain debug directory of 'cpu' """ path = self.target.path.join( SchedProcFSData.get_data_root(self.target), "cpu{}".format(cpu) - ) + ) if self.target.path else '' return SchedProcFSData(self.target, path) - def get_sd_info(self): + def get_sd_info(self) -> SchedProcFSData: """ :returns: An object view of the entire sched_domain debug directory """ return SchedProcFSData(self.target) - def get_capacity(self, cpu): + def get_capacity(self, cpu: int) -> int: """ :returns: The capacity of 'cpu' """ return self.get_capacities()[cpu] @memoized - def has_em(self, cpu, sd=None): + def has_em(self, cpu: int, sd: Optional[SchedProcFSData] = None) -> bool: """ :returns: Whether energy model data is available for 'cpu' """ if not sd: sd = self.get_cpu_sd_info(cpu) - return sd.procfs["domain0"].get("group0", {}).get("energy", {}).get("cap_states") != None + return sd.procfs["domain0"].get("group0", {}).get("energy", {}).get("cap_states") is not None @classmethod - def cpu_dmips_capacity_path(cls, target, cpu): + def cpu_dmips_capacity_path(cls, target: 'Target', cpu: int): """ :returns: The target sysfs path where the dmips capacity data should be """ return target.path.join( cls.cpu_sysfs_root, - 'cpu{}/cpu_capacity'.format(cpu)) + 'cpu{}/cpu_capacity'.format(cpu)) if target.path else '' @memoized - def has_dmips_capacity(self, cpu): + def has_dmips_capacity(self, cpu: int) -> bool: """ :returns: Whether dmips capacity data is available for 'cpu' """ @@ -540,21 +568,21 @@ def has_dmips_capacity(self, cpu): ) @memoized - def get_em_capacity(self, cpu, sd=None): + def get_em_capacity(self, cpu: int, sd: Optional[SchedProcFSData] = None) -> int: """ :returns: The maximum 
capacity value exposed by the EAS energy model """ if not sd: sd = self.get_cpu_sd_info(cpu) - cap_states = sd.domains[0].groups[0].energy.cap_states - cap_states_list = cap_states.split('\t') - num_cap_states = sd.domains[0].groups[0].energy.nr_cap_states - max_cap_index = -1 * int(len(cap_states_list) / num_cap_states) + cap_states: str = sd.domains[0].groups[0].energy.cap_states + cap_states_list: List[str] = cap_states.split('\t') + num_cap_states: int = sd.domains[0].groups[0].energy.nr_cap_states + max_cap_index: int = -1 * int(len(cap_states_list) / num_cap_states) return int(cap_states_list[max_cap_index]) @memoized - def get_dmips_capacity(self, cpu): + def get_dmips_capacity(self, cpu: int) -> int: """ :returns: The capacity value generated from the capacity-dmips-mhz DT entry """ @@ -562,7 +590,7 @@ def get_dmips_capacity(self, cpu): self.cpu_dmips_capacity_path(self.target, cpu), int ) - def get_capacities(self, default=None): + def get_capacities(self, default: Optional[int] = None) -> Dict[int, int]: """ :param default: Default capacity value to find if no data is found in procfs @@ -572,41 +600,41 @@ def get_capacities(self, default=None): :raises RuntimeError: Raised when no capacity information is found and 'default' is None """ - cpus = self.target.list_online_cpus() + cpus: List[int] = self.target.list_online_cpus() - capacities = {} + capacities: Dict[int, int] = {} for cpu in cpus: if self.has_dmips_capacity(cpu): capacities[cpu] = self.get_dmips_capacity(cpu) - missing_cpus = set(cpus).difference(capacities.keys()) + missing_cpus: Set[int] = set(cpus).difference(capacities.keys()) if not missing_cpus: return capacities if not SchedProcFSData.available(self.target): - if default != None: - capacities.update({cpu : default for cpu in missing_cpus}) + if default is not None: + capacities.update({cpu: cast(int, default) for cpu in missing_cpus}) return capacities else: raise RuntimeError( 'No capacity data for cpus 
{}'.format(sorted(missing_cpus))) - sd_info = self.get_sd_info() + sd_info: SchedProcFSData = self.get_sd_info() for cpu in missing_cpus: if self.has_em(cpu, sd_info.cpus[cpu]): capacities[cpu] = self.get_em_capacity(cpu, sd_info.cpus[cpu]) else: - if default != None: - capacities[cpu] = default + if default is not None: + capacities[cpu] = cast(int, default) else: raise RuntimeError('No capacity data for cpu{}'.format(cpu)) return capacities @memoized - def get_hz(self): + def get_hz(self) -> int: """ :returns: The scheduler tick frequency on the target """ - return int(self.target.config.get('CONFIG_HZ', strict=True)) + return int(cast(str, self.target.config.get('CONFIG_HZ', strict=True))) diff --git a/devlib/module/thermal.py b/devlib/module/thermal.py index d23739ea2..6a92a9c60 100644 --- a/devlib/module/thermal.py +++ b/devlib/module/thermal.py @@ -1,4 +1,4 @@ -# Copyright 2015-2018 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,139 +16,174 @@ import logging import devlib.utils.asyn as asyn +from devlib.utils.misc import get_logger from devlib.module import Module from devlib.exception import TargetStableCalledProcessError +from typing import (TYPE_CHECKING, Dict, Match, Optional, + Tuple, List) +if TYPE_CHECKING: + from devlib.target import Target + class TripPoint(object): - def __init__(self, zone, _id): + """ + Trip points are predefined temperature thresholds within a thermal zone. When the temperature reaches these points, + specific actions are triggered to manage the system's thermal state. There are typically three types of trip points: + + Active Trip Points: Trigger active cooling mechanisms like fans when the temperature exceeds a certain threshold. + Passive Trip Points: Initiate passive cooling strategies, such as reducing the processor's clock speed, to lower the temperature. 
+ Critical Trip Points: Indicate a critical temperature level that requires immediate action, such as shutting down the system to prevent damage + """ + def __init__(self, zone: 'ThermalZone', _id: str): self._id = _id self.zone = zone - self.temp_node = 'trip_point_' + _id + '_temp' - self.type_node = 'trip_point_' + _id + '_type' + self.temp_node: str = 'trip_point_' + _id + '_temp' + self.type_node: str = 'trip_point_' + _id + '_type' @property - def target(self): + def target(self) -> 'Target': + """ + target of the trip point + """ return self.zone.target @asyn.asyncf - async def get_temperature(self): + async def get_temperature(self) -> int: """Returns the currently configured temperature of the trip point""" - temp_file = self.target.path.join(self.zone.path, self.temp_node) + temp_file: str = self.target.path.join(self.zone.path, self.temp_node) if self.target.path else '' return await self.target.read_int.asyn(temp_file) @asyn.asyncf - async def set_temperature(self, temperature): - temp_file = self.target.path.join(self.zone.path, self.temp_node) + async def set_temperature(self, temperature: int) -> None: + """ + set temperature threshold for the trip point + """ + temp_file: str = self.target.path.join(self.zone.path, self.temp_node) if self.target.path else '' await self.target.write_value.asyn(temp_file, temperature) @asyn.asyncf - async def get_type(self): + async def get_type(self) -> str: """Returns the type of trip point""" - type_file = self.target.path.join(self.zone.path, self.type_node) + type_file: str = self.target.path.join(self.zone.path, self.type_node) if self.target.path else '' return await self.target.read_value.asyn(type_file) + class ThermalZone(object): - def __init__(self, target, root, _id): + """ + A thermal zone is a logical collection of interfaces to temperature sensors, trip points, + thermal property information, and thermal controls. 
These zones help manage the temperature + of various components within a system, such as CPUs, GPUs, and other hardware. + """ + def __init__(self, target: 'Target', root: str, _id: str): self.target = target self.name = 'thermal_zone' + _id - self.path = target.path.join(root, self.name) - self.trip_points = {} - self.type = self.target.read_value(self.target.path.join(self.path, 'type')) + self.path = target.path.join(root, self.name) if target.path else '' + self.trip_points: Dict[int, TripPoint] = {} + self.type: str = self.target.read_value(self.target.path.join(self.path, 'type') if self.target.path else '') for entry in self.target.list_directory(self.path, as_root=target.is_rooted): - re_match = re.match('^trip_point_([0-9]+)_temp', entry) + re_match: Optional[Match[str]] = re.match('^trip_point_([0-9]+)_temp', entry) if re_match is not None: self._add_trip_point(re_match.group(1)) - def _add_trip_point(self, _id): + def _add_trip_point(self, _id: str) -> None: + """ + add a trip point to the thermal zone + """ self.trip_points[int(_id)] = TripPoint(self, _id) @asyn.asyncf - async def is_enabled(self): + async def is_enabled(self) -> bool: """Returns a boolean representing the 'mode' of the thermal zone""" - value = await self.target.read_value.asyn(self.target.path.join(self.path, 'mode')) + value: str = await self.target.read_value.asyn(self.target.path.join(self.path, 'mode') if self.target.path else '') return value == 'enabled' @asyn.asyncf - async def set_enabled(self, enabled=True): + async def set_enabled(self, enabled: bool = True) -> None: + """ + enable or disable the thermal zone + """ value = 'enabled' if enabled else 'disabled' - await self.target.write_value.asyn(self.target.path.join(self.path, 'mode'), value) + await self.target.write_value.asyn(self.target.path.join(self.path, 'mode') if self.target.path else '', value) @asyn.asyncf - async def get_temperature(self): + async def get_temperature(self) -> int: """Returns the temperature of 
the thermal zone""" - sysfs_temperature_file = self.target.path.join(self.path, 'temp') + sysfs_temperature_file = self.target.path.join(self.path, 'temp') if self.target.path else '' return await self.target.read_int.asyn(sysfs_temperature_file) @asyn.asyncf - async def get_policy(self): + async def get_policy(self) -> str: """Returns the policy of the thermal zone""" - temp_file = self.target.path.join(self.path, 'policy') + temp_file = self.target.path.join(self.path, 'policy') if self.target.path else '' return await self.target.read_value.asyn(temp_file) @asyn.asyncf - async def set_policy(self, policy): + async def set_policy(self, policy: str) -> None: """ Sets the policy of the thermal zone :params policy: Thermal governor name - :type policy: str """ - await self.target.write_value.asyn(self.target.path.join(self.path, 'policy'), policy) + await self.target.write_value.asyn(self.target.path.join(self.path, 'policy') if self.target.path else '', policy) @asyn.asyncf - async def get_offset(self): + async def get_offset(self) -> int: """Returns the temperature offset of the thermal zone""" - offset_file = self.target.path.join(self.path, 'offset') + offset_file: str = self.target.path.join(self.path, 'offset') if self.target.path else '' return await self.target.read_value.asyn(offset_file) @asyn.asyncf - async def set_offset(self, offset): + async def set_offset(self, offset: int) -> None: """ Sets the temperature offset in milli-degrees of the thermal zone :params offset: Temperature offset in milli-degrees - :type policy: int """ - await self.target.write_value.asyn(self.target.path.join(self.path, 'offset'), policy) + await self.target.write_value.asyn(self.target.path.join(self.path, 'offset') if self.target.path else '', offset) @asyn.asyncf - async def set_emul_temp(self, offset): + async def set_emul_temp(self, offset: int) -> None: """ Sets the emulated temperature in milli-degrees of the thermal zone :params offset: Emulated temperature in 
milli-degrees - :type policy: int """ - await self.target.write_value.asyn(self.target.path.join(self.path, 'emul_temp'), policy) + await self.target.write_value.asyn(self.target.path.join(self.path, 'emul_temp') if self.target.path else '', offset) @asyn.asyncf - async def get_available_policies(self): + async def get_available_policies(self) -> str: """Returns the policies available for the thermal zone""" - temp_file = self.target.path.join(self.path, 'available_policies') + temp_file: str = self.target.path.join(self.path, 'available_policies') if self.target.path else '' return await self.target.read_value.asyn(temp_file) + class ThermalModule(Module): + """ + The /sys/class/thermal directory in Linux provides a sysfs interface for thermal management. + This directory contains subdirectories and files that represent thermal zones and cooling devices, + allowing users and applications to monitor and manage system temperatures. + """ name = 'thermal' thermal_root = '/sys/class/thermal' @staticmethod - def probe(target): - + def probe(target: 'Target') -> bool: if target.file_exists(ThermalModule.thermal_root): return True + return False - def __init__(self, target): + def __init__(self, target: 'Target'): super(ThermalModule, self).__init__(target) - self.logger = logging.getLogger(self.name) + self.logger: logging.Logger = get_logger(self.name) self.logger.debug('Initialized [%s] module', self.name) - self.zones = {} - self.cdevs = [] + self.zones: Dict[int, ThermalZone] = {} + self.cdevs: List = [] for entry in target.list_directory(self.thermal_root): - re_match = re.match('^(thermal_zone|cooling_device)([0-9]+)', entry) + re_match: Optional[Match[str]] = re.match('^(thermal_zone|cooling_device)([0-9]+)', entry) if not re_match: self.logger.warning('unknown thermal entry: %s', entry) continue @@ -159,29 +194,28 @@ def __init__(self, target): # TODO pass - def _add_thermal_zone(self, _id): + def _add_thermal_zone(self, _id: str) -> None: self.zones[int(_id)] = 
ThermalZone(self.target, self.thermal_root, _id) - def disable_all_zones(self): + def disable_all_zones(self) -> None: """Disables all the thermal zones in the target""" for zone in self.zones.values(): zone.set_enabled(False) @asyn.asyncf - async def get_all_temperatures(self, error='raise'): + async def get_all_temperatures(self, error: str = 'raise') -> Dict[str, int]: """ Returns dictionary with current reading of all thermal zones. :params error: Sensor read error handling (raise or ignore) - :type error: str :returns: a dictionary in the form: {tz_type:temperature} """ - async def get_temperature_noexcep(item): + async def get_temperature_noexcep(item: Tuple[str, ThermalZone]) -> Optional[int]: tzid, tz = item try: - temperature = await tz.get_temperature.asyn() + temperature: int = await tz.get_temperature.asyn() except TargetStableCalledProcessError as e: if error == 'raise': raise e diff --git a/devlib/module/vexpress.py b/devlib/module/vexpress.py index c597747be..f44d99ab1 100644 --- a/devlib/module/vexpress.py +++ b/devlib/module/vexpress.py @@ -1,5 +1,5 @@ # -# Copyright 2015-2018 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,25 +25,35 @@ from devlib.utils.serial_port import open_serial_connection, pulse_dtr, write_characters from devlib.utils.uefi import UefiMenu, UefiConfig from devlib.utils.uboot import UbootMenu +from devlib.platform.arm import VersatileExpressPlatform +# pylint: disable=ungrouped-imports +try: + from pexpect import fdpexpect +# pexpect < 4.0.0 does not have fdpexpect module +except ImportError: + import fdpexpect # type:ignore +from typing import TYPE_CHECKING, cast, Optional, Dict, Union, Any +if TYPE_CHECKING: + from devlib.target import Target -OLD_AUTOSTART_MESSAGE = 'Press Enter to stop auto boot...' 
-AUTOSTART_MESSAGE = 'Hit any key to stop autoboot:' -POWERUP_MESSAGE = 'Powering up system...' -DEFAULT_MCC_PROMPT = 'Cmd>' +OLD_AUTOSTART_MESSAGE: str = 'Press Enter to stop auto boot...' +AUTOSTART_MESSAGE: str = 'Hit any key to stop autoboot:' +POWERUP_MESSAGE: str = 'Powering up system...' +DEFAULT_MCC_PROMPT: str = 'Cmd>' class VexpressDtrHardReset(HardRestModule): - name = 'vexpress-dtr' - stage = 'early' + name: str = 'vexpress-dtr' + stage: str = 'early' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: return True - def __init__(self, target, port='/dev/ttyS0', baudrate=115200, - mcc_prompt=DEFAULT_MCC_PROMPT, timeout=300): + def __init__(self, target: 'Target', port: str = '/dev/ttyS0', baudrate: int = 115200, + mcc_prompt: str = DEFAULT_MCC_PROMPT, timeout: int = 300): super(VexpressDtrHardReset, self).__init__(target) self.port = port self.baudrate = baudrate @@ -59,7 +69,7 @@ def __call__(self): with open_serial_connection(port=self.port, baudrate=self.baudrate, timeout=self.timeout, - init_dtr=0, + init_dtr=False, get_conn=True) as (_, conn): pulse_dtr(conn, state=True, duration=0.1) # TRM specifies a pulse of >=100ms @@ -70,13 +80,13 @@ class VexpressReboottxtHardReset(HardRestModule): stage = 'early' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: return True - def __init__(self, target, - port='/dev/ttyS0', baudrate=115200, - path='/media/VEMSD', - mcc_prompt=DEFAULT_MCC_PROMPT, timeout=30, short_delay=1): + def __init__(self, target: 'Target', + port: str = '/dev/ttyS0', baudrate: int = 115200, + path: str = '/media/VEMSD', + mcc_prompt: str = DEFAULT_MCC_PROMPT, timeout: int = 30, short_delay: int = 1): super(VexpressReboottxtHardReset, self).__init__(target) self.port = port self.baudrate = baudrate @@ -98,7 +108,7 @@ def __call__(self): with open_serial_connection(port=self.port, baudrate=self.baudrate, timeout=self.timeout, - init_dtr=0) as tty: + init_dtr=False) as tty: 
wait_for_vemsd(self.path, tty, self.mcc_prompt, self.short_delay) with open(self.filepath, 'w'): pass @@ -109,13 +119,13 @@ class VexpressBootModule(BootModule): stage = 'early' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: return True - def __init__(self, target, uefi_entry=None, - port='/dev/ttyS0', baudrate=115200, - mcc_prompt=DEFAULT_MCC_PROMPT, - timeout=120, short_delay=1): + def __init__(self, target: 'Target', uefi_entry: Optional[str] = None, + port: str = '/dev/ttyS0', baudrate: int = 115200, + mcc_prompt: str = DEFAULT_MCC_PROMPT, + timeout: int = 120, short_delay: int = 1): super(VexpressBootModule, self).__init__(target) self.port = port self.baudrate = baudrate @@ -128,18 +138,24 @@ def __call__(self): with open_serial_connection(port=self.port, baudrate=self.baudrate, timeout=self.timeout, - init_dtr=0) as tty: + init_dtr=False) as tty: self.get_through_early_boot(tty) self.perform_boot_sequence(tty) self.wait_for_shell_prompt(tty) - def perform_boot_sequence(self, tty): + def perform_boot_sequence(self, tty: fdpexpect.fdspawn) -> None: + """ + boot up the vexpress + """ raise NotImplementedError() - def get_through_early_boot(self, tty): + def get_through_early_boot(self, tty: fdpexpect.fdspawn) -> None: + """ + do the things necessary during early boot + """ self.logger.debug('Establishing initial state...') tty.sendline('') - i = tty.expect([AUTOSTART_MESSAGE, OLD_AUTOSTART_MESSAGE, POWERUP_MESSAGE, self.mcc_prompt]) + i: int = tty.expect([AUTOSTART_MESSAGE, OLD_AUTOSTART_MESSAGE, POWERUP_MESSAGE, self.mcc_prompt]) if i == 3: self.logger.debug('Saw MCC prompt.') time.sleep(self.short_delay) @@ -154,13 +170,13 @@ def get_through_early_boot(self, tty): tty.sendline('reboot') tty.sendline('reset') - def get_uefi_menu(self, tty): + def get_uefi_menu(self, tty: fdpexpect.fdspawn) -> UefiMenu: menu = UefiMenu(tty) self.logger.debug('Waiting for UEFI menu...') menu.wait(timeout=self.timeout) return menu - def 
wait_for_shell_prompt(self, tty): + def wait_for_shell_prompt(self, tty: fdpexpect.fdspawn) -> None: self.logger.debug('Waiting for the shell prompt.') tty.expect(self.target.shell_prompt, timeout=self.timeout) # This delay is needed to allow the platform some time to finish @@ -171,17 +187,17 @@ def wait_for_shell_prompt(self, tty): class VexpressUefiBoot(VexpressBootModule): - name = 'vexpress-uefi' + name: str = 'vexpress-uefi' - def __init__(self, target, uefi_entry, - image, fdt, bootargs, initrd, + def __init__(self, target: 'Target', uefi_entry: Optional[str], + image: str, fdt: str, bootargs: str, initrd: str, *args, **kwargs): - super(VexpressUefiBoot, self).__init__(target, uefi_entry=uefi_entry, + super(VexpressUefiBoot, self).__init__(target, uefi_entry, *args, **kwargs) - self.uefi_config = self._create_config(image, fdt, bootargs, initrd) + self.uefi_config: UefiConfig = self._create_config(image, fdt, bootargs, initrd) - def perform_boot_sequence(self, tty): - menu = self.get_uefi_menu(tty) + def perform_boot_sequence(self, tty: fdpexpect.fdspawn) -> None: + menu: UefiMenu = self.get_uefi_menu(tty) try: menu.select(self.uefi_entry) except LookupError: @@ -190,8 +206,8 @@ def perform_boot_sequence(self, tty): menu.create_entry(self.uefi_entry, self.uefi_config) menu.select(self.uefi_entry) - def _create_config(self, image, fdt, bootargs, initrd): # pylint: disable=R0201 - config_dict = { + def _create_config(self, image: str, fdt: str, bootargs: str, initrd: str): # pylint: disable=R0201 + config_dict: Dict[str, Union[str, bool]] = { 'image_name': image, 'image_args': bootargs, 'initrd': initrd, @@ -208,21 +224,21 @@ def _create_config(self, image, fdt, bootargs, initrd): # pylint: disable=R0201 class VexpressUefiShellBoot(VexpressBootModule): - name = 'vexpress-uefi-shell' + name: str = 'vexpress-uefi-shell' # pylint: disable=keyword-arg-before-vararg - def __init__(self, target, uefi_entry='^Shell$', - efi_shell_prompt='Shell>', - image='kernel', 
bootargs=None, + def __init__(self, target: 'Target', uefi_entry: Optional[str] = '^Shell$', + efi_shell_prompt: str = 'Shell>', + image: str = 'kernel', bootargs: Optional[str] = None, *args, **kwargs): - super(VexpressUefiShellBoot, self).__init__(target, uefi_entry=uefi_entry, + super(VexpressUefiShellBoot, self).__init__(target, uefi_entry, *args, **kwargs) self.efi_shell_prompt = efi_shell_prompt self.image = image self.bootargs = bootargs - def perform_boot_sequence(self, tty): - menu = self.get_uefi_menu(tty) + def perform_boot_sequence(self, tty: fdpexpect.fdspawn) -> None: + menu: UefiMenu = self.get_uefi_menu(tty) try: menu.select(self.uefi_entry) except LookupError: @@ -239,15 +255,15 @@ def perform_boot_sequence(self, tty): class VexpressUBoot(VexpressBootModule): - name = 'vexpress-u-boot' + name: str = 'vexpress-u-boot' # pylint: disable=keyword-arg-before-vararg - def __init__(self, target, env=None, + def __init__(self, target: 'Target', env: Optional[Dict] = None, *args, **kwargs): super(VexpressUBoot, self).__init__(target, *args, **kwargs) self.env = env - def perform_boot_sequence(self, tty): + def perform_boot_sequence(self, tty: fdpexpect.fdspawn) -> None: if self.env is None: return # Will boot automatically @@ -261,13 +277,13 @@ def perform_boot_sequence(self, tty): class VexpressBootmon(VexpressBootModule): - name = 'vexpress-bootmon' + name: str = 'vexpress-bootmon' # pylint: disable=keyword-arg-before-vararg - def __init__(self, target, - image, fdt, initrd, bootargs, - uses_bootscript=False, - bootmon_prompt='>', + def __init__(self, target: 'Target', + image: str, fdt: str, initrd: str, bootargs: str, + uses_bootscript: bool = False, + bootmon_prompt: str = '>', *args, **kwargs): super(VexpressBootmon, self).__init__(target, *args, **kwargs) self.image = image @@ -277,7 +293,7 @@ def __init__(self, target, self.uses_bootscript = uses_bootscript self.bootmon_prompt = bootmon_prompt - def perform_boot_sequence(self, tty): + def 
perform_boot_sequence(self, tty: fdpexpect.fdspawn) -> None: if self.uses_bootscript: return # Will boot automatically @@ -286,7 +302,7 @@ def perform_boot_sequence(self, tty): with open_serial_connection(port=self.port, baudrate=self.baudrate, timeout=self.timeout, - init_dtr=0) as tty_conn: + init_dtr=False) as tty_conn: write_characters(tty_conn, 'fl linux fdt {}'.format(self.fdt)) write_characters(tty_conn, 'fl linux initrd {}'.format(self.initrd)) write_characters(tty_conn, 'fl linux boot {} {}'.format(self.image, @@ -295,8 +311,8 @@ def perform_boot_sequence(self, tty): class VersatileExpressFlashModule(FlashModule): - name = 'vexpress-vemsd' - description = """ + name: str = 'vexpress-vemsd' + description: str = """ Enables flashing of kernels and firmware to ARM Versatile Express devices. This modules enables flashing of image bundles or individual images to ARM @@ -311,31 +327,34 @@ class VersatileExpressFlashModule(FlashModule): """ - stage = 'early' + stage: str = 'early' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: if not target.has('hard_reset'): return False return True - def __init__(self, target, vemsd_mount, mcc_prompt=DEFAULT_MCC_PROMPT, timeout=30, short_delay=1): + def __init__(self, target: 'Target', vemsd_mount: str, + mcc_prompt: str = DEFAULT_MCC_PROMPT, timeout: int = 30, short_delay: int = 1): super(VersatileExpressFlashModule, self).__init__(target) self.vemsd_mount = vemsd_mount self.mcc_prompt = mcc_prompt self.timeout = timeout self.short_delay = short_delay - def __call__(self, image_bundle=None, images=None, bootargs=None, connect=True): - self.target.hard_reset() - with open_serial_connection(port=self.target.platform.serial_port, - baudrate=self.target.platform.baudrate, + def __call__(self, image_bundle: Optional[str] = None, + images: Optional[Dict[str, str]] = None, + bootargs: Any = None, connect: bool = True): + cast(HardRestModule, self.target.hard_reset)() + with 
open_serial_connection(port=cast(VersatileExpressPlatform, self.target.platform).serial_port, + baudrate=cast(VersatileExpressPlatform, self.target.platform).baudrate, timeout=self.timeout, - init_dtr=0) as tty: + init_dtr=False) as tty: # pylint: disable=no-member - i = tty.expect([self.mcc_prompt, AUTOSTART_MESSAGE, OLD_AUTOSTART_MESSAGE]) + i: int = cast(fdpexpect.fdspawn, tty).expect([self.mcc_prompt, AUTOSTART_MESSAGE, OLD_AUTOSTART_MESSAGE]) if i: - tty.sendline('') # pylint: disable=no-member + cast(fdpexpect.fdspawn, tty).sendline('') # pylint: disable=no-member wait_for_vemsd(self.vemsd_mount, tty, self.mcc_prompt, self.short_delay) try: if image_bundle: @@ -344,20 +363,20 @@ def __call__(self, image_bundle=None, images=None, bootargs=None, connect=True): self._overlay_images(images) os.system('sync') except (IOError, OSError) as e: - msg = 'Could not deploy images to {}; got: {}' + msg: str = 'Could not deploy images to {}; got: {}' raise TargetStableError(msg.format(self.vemsd_mount, e)) - self.target.boot() + cast(BootModule, self.target.boot)() if connect: self.target.connect(timeout=30) - def _deploy_image_bundle(self, bundle): + def _deploy_image_bundle(self, bundle: str) -> None: self.logger.debug('Validating {}'.format(bundle)) validate_image_bundle(bundle) self.logger.debug('Extracting {} into {}...'.format(bundle, self.vemsd_mount)) with tarfile.open(bundle) as tar: safe_extract(tar, self.vemsd_mount) - def _overlay_images(self, images): + def _overlay_images(self, images: Dict[str, str]): for dest, src in images.items(): dest = os.path.join(self.vemsd_mount, dest) self.logger.debug('Copying {} to {}'.format(src, dest)) @@ -366,7 +385,7 @@ def _overlay_images(self, images): # utility functions -def validate_image_bundle(bundle): +def validate_image_bundle(bundle: str) -> None: if not tarfile.is_tarfile(bundle): raise HostError('Image bundle {} does not appear to be a valid TAR file.'.format(bundle)) with tarfile.open(bundle) as tar: @@ -380,9 
+399,11 @@ def validate_image_bundle(bundle): raise HostError(msg.format(bundle)) -def wait_for_vemsd(vemsd_mount, tty, mcc_prompt=DEFAULT_MCC_PROMPT, short_delay=1, retries=3): - attempts = 1 + retries - path = os.path.join(vemsd_mount, 'config.txt') +def wait_for_vemsd(vemsd_mount: str, tty: fdpexpect.fdspawn, + mcc_prompt: str = DEFAULT_MCC_PROMPT, short_delay: int = 1, + retries: int = 3) -> None: + attempts: int = 1 + retries + path: str = os.path.join(vemsd_mount, 'config.txt') if os.path.exists(path): return for _ in range(attempts): diff --git a/devlib/platform/__init__.py b/devlib/platform/__init__.py index 205b5c624..5c0e80614 100644 --- a/devlib/platform/__init__.py +++ b/devlib/platform/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,11 @@ # limitations under the License. # -import logging +from typing import Optional, List, TYPE_CHECKING, cast, Dict +from devlib.utils.misc import get_logger +if TYPE_CHECKING: + from devlib.target import Target, AndroidTarget + from devlib.utils.types import caseless_string BIG_CPUS = ['A15', 'A57', 'A72', 'A73'] @@ -22,34 +26,37 @@ class Platform(object): @property - def number_of_clusters(self): + def number_of_clusters(self) -> int: return len(set(self.core_clusters)) def __init__(self, - name=None, - core_names=None, - core_clusters=None, - big_core=None, - model=None, - modules=None, + name: Optional[str] = None, + core_names: Optional[List['caseless_string']] = None, + core_clusters: Optional[List[int]] = None, + big_core: Optional[str] = None, + model: Optional[str] = None, + modules: Optional[List[Dict[str, Dict]]] = None, ): self.name = name self.core_names = core_names or [] self.core_clusters = core_clusters or [] self.big_core = big_core - self.little_core = None + self.little_core: Optional[caseless_string] = None 
self.model = model self.modules = modules or [] - self.logger = logging.getLogger(self.name) + self.logger = get_logger(self.name) if not self.core_clusters and self.core_names: self._set_core_clusters_from_core_names() - def init_target_connection(self, target): + def init_target_connection(self, target: 'Target') -> None: + """ + do platform specific initialization for the connection + """ # May be ovewritten by subclasses to provide target-specific # connection initialisation. pass - def update_from_target(self, target): + def update_from_target(self, target: 'Target') -> None: if not self.core_names: self.core_names = target.cpuinfo.cpu_names self._set_core_clusters_from_core_names() @@ -63,25 +70,28 @@ def update_from_target(self, target): self.name = self.model self._validate() - def setup(self, target): + def setup(self, target: 'Target') -> None: + """ + Platform specific setup + """ # May be overwritten by subclasses to provide platform-specific # setup procedures. pass - def _set_core_clusters_from_core_names(self): + def _set_core_clusters_from_core_names(self) -> None: self.core_clusters = [] - clusters = [] + clusters: List[str] = [] for cn in self.core_names: if cn not in clusters: clusters.append(cn) self.core_clusters.append(clusters.index(cn)) - def _set_model_from_target(self, target): + def _set_model_from_target(self, target: 'Target'): if target.os == 'android': try: - self.model = target.getprop(prop='ro.product.device') + self.model = cast('AndroidTarget', target).getprop(prop='ro.product.device') except KeyError: - self.model = target.getprop('ro.product.model') + self.model = cast('AndroidTarget', target).getprop('ro.product.model') elif target.file_exists("/proc/device-tree/model"): # There is currently no better way to do this cross platform. 
# ARM does not have dmidecode @@ -95,21 +105,21 @@ def _set_model_from_target(self, target): except Exception: # pylint: disable=broad-except pass # this is best-effort - def _identify_big_core(self): + def _identify_big_core(self) -> 'caseless_string': for core in self.core_names: if core.upper() in BIG_CPUS: return core big_idx = self.core_clusters.index(max(self.core_clusters)) return self.core_names[big_idx] - def _validate(self): + def _validate(self) -> None: if len(self.core_names) != len(self.core_clusters): raise ValueError('core_names and core_clusters are of different lengths.') if self.big_core and self.number_of_clusters != 2: raise ValueError('attempting to set big_core on non-big.LITTLE device. ' '(number of clusters is not 2)') if self.big_core and self.big_core not in self.core_names: - message = 'Invalid big_core value "{}"; must be in [{}]' + message: str = 'Invalid big_core value "{}"; must be in [{}]' raise ValueError(message.format(self.big_core, ', '.join(set(self.core_names)))) if self.big_core: diff --git a/devlib/platform/arm.py b/devlib/platform/arm.py index 6499ec88e..e97bcf31e 100644 --- a/devlib/platform/arm.py +++ b/devlib/platform/arm.py @@ -1,4 +1,4 @@ -# Copyright 2015-2024 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -25,38 +25,53 @@ from devlib.utils.csvutil import csvreader, csvwriter from devlib.utils.serial_port import open_serial_connection +# pylint: disable=ungrouped-imports +try: + from pexpect import fdpexpect +# pexpect < 4.0.0 does not have fdpexpect module +except ImportError: + import fdpexpect # type:ignore + +from typing import (cast, TYPE_CHECKING, Match, Optional, + List, Dict, OrderedDict) +from devlib.utils.types import caseless_string +from devlib.utils.annotation_helpers import AdbUserConnectionSettings +from signal import Signals +if TYPE_CHECKING: + from devlib.target import Target + class VersatileExpressPlatform(Platform): - def __init__(self, name, # pylint: disable=too-many-locals + def __init__(self, name: str, # pylint: disable=too-many-locals - core_names=None, - core_clusters=None, - big_core=None, - model=None, - modules=None, + core_names: Optional[List[caseless_string]] = None, + core_clusters: Optional[List[int]] = None, + big_core: Optional[str] = None, + model: Optional[str] = None, + modules: Optional[List[Dict[str, Dict]]] = None, # serial settings - serial_port='/dev/ttyS0', - baudrate=115200, + serial_port: str = '/dev/ttyS0', + baudrate: int = 115200, # VExpress MicroSD mount point - vemsd_mount=None, + vemsd_mount: Optional[str] = None, # supported: dtr, reboottxt - hard_reset_method=None, + hard_reset_method: Optional[str] = None, # supported: uefi, uefi-shell, u-boot, bootmon - bootloader=None, + bootloader: Optional[str] = None, # supported: vemsd - flash_method='vemsd', + flash_method: str = 'vemsd', - image=None, - fdt=None, - initrd=None, - bootargs=None, + image: Optional[str] = None, + fdt: Optional[str] = None, + initrd: Optional[str] = None, + bootargs: Optional[str] = None, - uefi_entry=None, # only used if bootloader is "uefi" - ready_timeout=60, + uefi_entry: Optional[str] = None, # only used if bootloader is "uefi" + ready_timeout: int = 60, ): super(VersatileExpressPlatform, self).__init__(name, core_names, @@ -73,56 
+88,56 @@ def __init__(self, name, # pylint: disable=too-many-locals self.bootargs = bootargs self.uefi_entry = uefi_entry self.ready_timeout = ready_timeout - self.bootloader = None - self.hard_reset_method = None + self.bootloader: Optional[str] = None + self.hard_reset_method: Optional[str] = None self._set_bootloader(bootloader) self._set_hard_reset_method(hard_reset_method) self._set_flash_method(flash_method) - def init_target_connection(self, target): + def init_target_connection(self, target: 'Target') -> None: if target.os == 'android': self._init_android_target(target) else: self._init_linux_target(target) - def _init_android_target(self, target): + def _init_android_target(self, target: 'Target') -> None: if target.connection_settings.get('device') is None: addr = self._get_target_ip_address(target) - target.connection_settings['device'] = addr + ':5555' + cast(AdbUserConnectionSettings, target.connection_settings)['device'] = addr + ':5555' - def _init_linux_target(self, target): + def _init_linux_target(self, target: 'Target') -> None: if target.connection_settings.get('host') is None: addr = self._get_target_ip_address(target) target.connection_settings['host'] = addr # pylint: disable=no-member - def _get_target_ip_address(self, target): + def _get_target_ip_address(self, target: 'Target') -> str: with open_serial_connection(port=self.serial_port, baudrate=self.baudrate, timeout=30, - init_dtr=0) as tty: - tty.sendline('su') # this is, apprently, required to query network device - # info by name on recent Juno builds... + init_dtr=False) as tty: + cast(fdpexpect.fdspawn, tty).sendline('su') # this is, apprently, required to query network device + # info by name on recent Juno builds... 
self.logger.debug('Waiting for the shell prompt.') - tty.expect(target.shell_prompt) + cast(fdpexpect.fdspawn, tty).expect(target.shell_prompt) self.logger.debug('Waiting for IP address...') - wait_start_time = time.time() + wait_start_time: float = time.time() try: while True: - tty.sendline('ip addr list eth0') + cast(fdpexpect.fdspawn, tty).sendline('ip addr list eth0') time.sleep(1) try: - tty.expect(r'inet ([1-9]\d*.\d+.\d+.\d+)', timeout=10) - return tty.match.group(1).decode('utf-8') + cast(fdpexpect.fdspawn, tty).expect(r'inet ([1-9]\d*.\d+.\d+.\d+)', timeout=10) + return cast(Match[bytes], cast(fdpexpect.fdspawn, tty).match).group(1).decode('utf-8') except pexpect.TIMEOUT: pass # We have our own timeout -- see below. if (time.time() - wait_start_time) > self.ready_timeout: raise TargetTransientError('Could not acquire IP address.') finally: - tty.sendline('exit') # exit shell created by "su" call at the start + cast(fdpexpect.fdspawn, tty).sendline('exit') # exit shell created by "su" call at the start - def _set_hard_reset_method(self, hard_reset_method): + def _set_hard_reset_method(self, hard_reset_method: Optional[str]) -> None: if hard_reset_method == 'dtr': self.modules.append({'vexpress-dtr': {'port': self.serial_port, 'baudrate': self.baudrate, @@ -135,7 +150,7 @@ def _set_hard_reset_method(self, hard_reset_method): else: ValueError('Invalid hard_reset_method: {}'.format(hard_reset_method)) - def _set_bootloader(self, bootloader): + def _set_bootloader(self, bootloader: Optional[str]) -> None: self.bootloader = bootloader if self.bootloader == 'uefi': self.modules.append({'vexpress-uefi': {'port': self.serial_port, @@ -152,7 +167,7 @@ def _set_bootloader(self, bootloader): 'bootargs': self.bootargs, }}) elif self.bootloader == 'u-boot': - uboot_env = None + uboot_env: Optional[Dict[str, str]] = None if self.bootargs: uboot_env = {'bootargs': self.bootargs} self.modules.append({'vexpress-u-boot': {'port': self.serial_port, @@ -170,7 +185,7 @@ def 
_set_bootloader(self, bootloader): else: ValueError('Invalid hard_reset_method: {}'.format(bootloader)) - def _set_flash_method(self, flash_method): + def _set_flash_method(self, flash_method: str) -> None: if flash_method == 'vemsd': self.modules.append({'vexpress-vemsd': {'vemsd_mount': self.vemsd_mount}}) else: @@ -180,10 +195,10 @@ def _set_flash_method(self, flash_method): class Juno(VersatileExpressPlatform): def __init__(self, - vemsd_mount='/media/JUNO', - baudrate=115200, - bootloader='u-boot', - hard_reset_method='dtr', + vemsd_mount: str = '/media/JUNO', + baudrate: int = 115200, + bootloader: str = 'u-boot', + hard_reset_method: str = 'dtr', **kwargs ): super(Juno, self).__init__('juno', @@ -197,10 +212,10 @@ def __init__(self, class TC2(VersatileExpressPlatform): def __init__(self, - vemsd_mount='/media/VEMSD', - baudrate=38400, - bootloader='bootmon', - hard_reset_method='reboottxt', + vemsd_mount: str = '/media/VEMSD', + baudrate: int = 38400, + bootloader: str = 'bootmon', + hard_reset_method: str = 'reboottxt', **kwargs ): super(TC2, self).__init__('tc2', @@ -213,10 +228,10 @@ def __init__(self, class JunoEnergyInstrument(Instrument): - binname = 'readenergy' - mode = CONTINUOUS | INSTANTANEOUS + binname: str = 'readenergy' + mode: int = CONTINUOUS | INSTANTANEOUS - _channels = [ + _channels: List[InstrumentChannel] = [ InstrumentChannel('sys', 'current'), InstrumentChannel('a57', 'current'), InstrumentChannel('a53', 'current'), @@ -235,45 +250,47 @@ class JunoEnergyInstrument(Instrument): InstrumentChannel('gpu', 'energy'), ] - def __init__(self, target): + def __init__(self, target: 'Target'): super(JunoEnergyInstrument, self).__init__(target) - self.on_target_file = None - self.command = None - self.binary = self.target.bin(self.binname) + self.on_target_file: Optional[str] = None + self.command: Optional[str] = None + self.binary: str = self.target.bin(self.binname) for chan in self._channels: - self.channels[chan.name] = chan - 
self.on_target_file = self.target.tempfile('energy', '.csv') - self.sample_rate_hz = 10 # DEFAULT_PERIOD is 100[ms] in readenergy.c + self.channels[cast(str, chan.name)] = chan + self.on_target_file = cast(Target, self.target).tempfile('energy', '.csv') + self.sample_rate_hz: int = 10 # DEFAULT_PERIOD is 100[ms] in readenergy.c self.command = '{} -o {}'.format(self.binary, self.on_target_file) - self.command2 = '{}'.format(self.binary) + self.command2: str = '{}'.format(self.binary) - def setup(self): # pylint: disable=arguments-differ - self.binary = self.target.install(os.path.join(PACKAGE_BIN_DIRECTORY, - self.target.abi, self.binname)) + def setup(self) -> None: # pylint: disable=arguments-differ + self.binary = cast(Target, self.target).install(os.path.join(PACKAGE_BIN_DIRECTORY, + self.target.abi or '', self.binname)) self.command = '{} -o {}'.format(self.binary, self.on_target_file) self.command2 = '{}'.format(self.binary) - def reset(self, sites=None, kinds=None, channels=None): + def reset(self, sites: Optional[List[str]] = None, + kinds: Optional[List[str]] = None, + channels: Optional[OrderedDict[str, InstrumentChannel]] = None): super(JunoEnergyInstrument, self).reset(sites, kinds, channels) - self.target.killall(self.binname, as_root=True) + cast(Target, self.target).killall(self.binname, as_root=True) - def start(self): - self.target.kick_off(self.command, as_root=True) + def start(self) -> None: + cast(Target, self.target).kick_off(self.command, as_root=True) - def stop(self): - self.target.killall(self.binname, signal='TERM', as_root=True) + def stop(self) -> None: + cast(Target, self.target).killall(self.binname, signal=cast(Signals, 'TERM'), as_root=True) # pylint: disable=arguments-differ - def get_data(self, output_file): - temp_file = tempfile.mktemp() - self.target.pull(self.on_target_file, temp_file) - self.target.remove(self.on_target_file) + def get_data(self, output_file: str) -> MeasurementsCsv: + temp_file: str = tempfile.mktemp() + 
cast(Target, self.target).pull(self.on_target_file, temp_file) + cast(Target, self.target).remove(self.on_target_file) with csvreader(temp_file) as reader: headings = next(reader) # Figure out which columns from the collected csv we actually want - select_columns = [] + select_columns: List[int] = [] for chan in self.active_channels: try: select_columns.append(headings.index(chan.name)) @@ -281,22 +298,22 @@ def get_data(self, output_file): raise HostError('Channel "{}" is not in {}'.format(chan.name, temp_file)) with csvwriter(output_file) as writer: - write_headings = ['{}_{}'.format(c.site, c.kind) - for c in self.active_channels] + write_headings: List[str] = ['{}_{}'.format(c.site, c.kind) + for c in self.active_channels] writer.writerow(write_headings) for row in reader: - write_row = [row[c] for c in select_columns] + write_row: List[str] = [row[c] for c in select_columns] writer.writerow(write_row) return MeasurementsCsv(output_file, self.active_channels, sample_rate_hz=10) - def take_measurement(self): - result = [] + def take_measurement(self) -> List[Measurement]: + result: List[Measurement] = [] output = self.target.execute(self.command2).split() with csvreader(output) as reader: headings = next(reader) values = next(reader) for chan in self.active_channels: value = values[headings.index(chan.name)] - result.append(Measurement(value, chan)) + result.append(Measurement(cast(float, value), chan)) return result diff --git a/devlib/target.py b/devlib/target.py index 48c0acf04..710d19f18 100644 --- a/devlib/target.py +++ b/devlib/target.py @@ -1,4 +1,4 @@ -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # +""" +Target module for devlib. 
+This module defines the Target class and supporting functionality. +""" import atexit import asyncio -from contextlib import contextmanager import io import base64 import functools @@ -24,6 +27,7 @@ import os from operator import itemgetter import re +import sys import time import logging import posixpath @@ -37,55 +41,77 @@ import inspect import itertools from collections import namedtuple, defaultdict -from past.builtins import long -from past.types import basestring from numbers import Number from shlex import quote from weakref import WeakMethod try: from collections.abc import Mapping except ImportError: - from collections import Mapping + from collections import Mapping # type: ignore from enum import Enum from concurrent.futures import ThreadPoolExecutor from devlib.host import LocalConnection, PACKAGE_BIN_DIRECTORY -from devlib.module import get_module, Module +from devlib.module import get_module, Module, HardRestModule, BootModule from devlib.platform import Platform from devlib.exception import (DevlibTransientError, TargetStableError, TargetNotRespondingError, TimeoutError, TargetTransientError, KernelConfigKeyError, - TargetError, HostError, TargetCalledProcessError) + TargetError, HostError, TargetCalledProcessError, + DevlibError) from devlib.utils.ssh import SshConnection -from devlib.utils.android import AdbConnection, AndroidProperties, LogcatMonitor, adb_command, INTENT_FLAGS -from devlib.utils.misc import memoized, isiterable, convert_new_lines, groupby_value -from devlib.utils.misc import commonprefix, merge_lists -from devlib.utils.misc import ABI_MAP, get_cpu_name, ranges_to_list -from devlib.utils.misc import batch_contextmanager, tls_property, _BoundTLSProperty, nullcontext -from devlib.utils.misc import safe_extract -from devlib.utils.types import integer, boolean, bitmask, identifier, caseless_string, bytes_regex +from devlib.utils.android import (AdbConnection, AndroidProperties, + LogcatMonitor, adb_command, INTENT_FLAGS) +from 
devlib.utils.misc import (memoized, isiterable, convert_new_lines, + groupby_value, commonprefix, ABI_MAP, get_cpu_name, + ranges_to_list, batch_contextmanager, tls_property, + _BoundTLSProperty, nullcontext, safe_extract, get_logger) +from devlib.utils.types import (integer, boolean, bitmask, identifier, + caseless_string, bytes_regex) import devlib.utils.asyn as asyn - - -FSTAB_ENTRY_REGEX = re.compile(r'(\S+) on (.+) type (\S+) \((\S+)\)') -ANDROID_SCREEN_STATE_REGEX = re.compile('(?:mPowerState|mScreenOn|mWakefulness|Display Power: state)=([0-9]+|true|false|ON|OFF|DOZE|Dozing|Asleep|Awake)', - re.IGNORECASE) -ANDROID_SCREEN_RESOLUTION_REGEX = re.compile(r'cur=(?P\d+)x(?P\d+)') -ANDROID_SCREEN_ROTATION_REGEX = re.compile(r'orientation=(?P[0-3])') -DEFAULT_SHELL_PROMPT = re.compile(r'^.*(shell|root|juno)@?.*:[/~]\S* *[#$] ', - re.MULTILINE) -KVERSION_REGEX = re.compile( +from devlib.utils.annotation_helpers import (SshUserConnectionSettings, UserConnectionSettings, + AdbUserConnectionSettings, SupportedConnections, + SubprocessCommand, BackgroundCommand) +from typing import (List, Set, Dict, Union, Optional, Callable, TypeVar, + Any, cast, TYPE_CHECKING, Type, Pattern, + Tuple, Iterator, AsyncContextManager, Iterable, + Mapping as Maptype) +from collections.abc import AsyncGenerator +from types import ModuleType +from typing_extensions import Literal +import signal +if TYPE_CHECKING: + from devlib.connection import ConnectionBase + from devlib.utils.misc import InitCheckpointMeta + from devlib.utils.asyn import AsyncManager, _AsyncPolymorphicFunction + from asyncio import AbstractEventLoop + from contextlib import _GeneratorContextManager + from re import Match + from xml.dom.minidom import Document + + +FSTAB_ENTRY_REGEX: Pattern[str] = re.compile(r'(\S+) on (.+) type (\S+) \((\S+)\)') +ANDROID_SCREEN_STATE_REGEX: Pattern[str] = re.compile('(?:mPowerState|mScreenOn|mWakefulness|Display Power: state)=([0-9]+|true|false|ON|OFF|DOZE|Dozing|Asleep|Awake)', + 
re.IGNORECASE) +ANDROID_SCREEN_RESOLUTION_REGEX: Pattern[str] = re.compile(r'cur=(?P\d+)x(?P\d+)') +ANDROID_SCREEN_ROTATION_REGEX: Pattern[str] = re.compile(r'orientation=(?P[0-3])') +DEFAULT_SHELL_PROMPT: Pattern[str] = re.compile(r'^.*(shell|root|juno)@?.*:[/~]\S* *[#$] ', + re.MULTILINE) +KVERSION_REGEX: Pattern[str] = re.compile( r'(?P\d+)(\.(?P\d+)(\.(?P\d+))?(-rc(?P\d+))?)?(-android(?P[0-9]+))?(-(?P\d+)-g(?P[0-9a-fA-F]{7,}))?(-ab(?P[0-9]+))?' ) -GOOGLE_DNS_SERVER_ADDRESS = '8.8.8.8' +GOOGLE_DNS_SERVER_ADDRESS: str = '8.8.8.8' installed_package_info = namedtuple('installed_package_info', 'apk_path package') +T = TypeVar('T', bound=Callable[..., Any]) -def call_conn(f): + +# FIXME - need to annotate to indicate the self argument needs to have a conn object of ConnectionBase type. +def call_conn(f: T) -> T: """ Decorator to be used on all :class:`devlib.target.Target` methods that directly use a method of ``self.conn``. @@ -96,13 +122,18 @@ def call_conn(f): ``__del__``, which could be executed by the garbage collector, interrupting another call to a method of the connection instance. + :param f: Method to decorate. + + :returns: The wrapped method that automatically creates and releases + a new connection if reentered. + .. note:: This decorator could be applied directly to all methods with a metaclass or ``__init_subclass__`` but it could create issues when passing target methods as callbacks to connections' methods. """ @functools.wraps(f) - def wrapper(self, *args, **kwargs): + def wrapper(self, *args: Any, **kwargs: Any) -> Any: conn = self.conn reentered = conn.is_in_use disconnect = False @@ -132,97 +163,365 @@ def wrapper(self, *args, **kwargs): with self._lock: self._unused_conns.add(conn) - return wrapper + return cast(T, wrapper) class Target(object): + """ + An abstract base class defining the interface for a devlib target device. + + :param connection_settings: Connection parameters for the target + (e.g., SSH, ADB) in a dictionary. 
+ :param platform: A platform object describing architecture, ABI, kernel, + etc. If ``None``, platform info may be inferred or left unspecified. + :param working_directory: A writable directory on the target for devlib's + temporary files or scripts. If ``None``, a default path is used. + :param executables_directory: A directory on the target for storing + executables installed by devlib. If ``None``, a default path may be used. + :param connect: If ``True``, attempt to connect to the device immediately, + else call :meth:`connect` manually. + :param modules: Dict mapping module names to their parameters. Additional + devlib modules to load on initialization. + :param load_default_modules: If ``True``, load the modules specified in + :attr:`default_modules`. + :param shell_prompt: Compiled regex matching the target’s shell prompt. + :param conn_cls: A reference to the Connection class to be used. + :param is_container: If ``True``, indicates the target is a container + rather than a physical or virtual machine. + :param max_async: Number of asynchronous operations supported. Affects the + creation of parallel connections. + + :raises: Various :class:`devlib.exception` types if connection fails. + + .. note:: + Subclasses must implement device-specific methods (e.g., for Android vs. Linux or + specialized boards). The default implementation here may be incomplete. 
+ """ + path: Optional[ModuleType] = None + os: Optional[str] = None + system_id: Optional[str] = None - path = None - os = None - system_id = None + default_modules: List[Type[Module]] = [] - default_modules = [] + def __init__(self, + connection_settings: Optional[UserConnectionSettings] = None, + platform: Optional[Platform] = None, + working_directory: Optional[str] = None, + executables_directory: Optional[str] = None, + connect: bool = True, + modules: Optional[Dict[str, Dict[str, Type[Module]]]] = None, + load_default_modules: bool = True, + shell_prompt: Pattern[str] = DEFAULT_SHELL_PROMPT, + conn_cls: Optional['InitCheckpointMeta'] = None, + is_container: bool = False, + max_async: int = 50, + tmp_directory: Optional[str] = None, + ): + """ + Initialize a new Target instance and optionally connect to it. + """ + self._lock = threading.RLock() + self._async_pool: Optional[ThreadPoolExecutor] = None + self._async_pool_size: Optional[int] = None + self._unused_conns: Set[ConnectionBase] = set() + + self._is_rooted: Optional[bool] = None + self.connection_settings: UserConnectionSettings = connection_settings or {} + # Set self.platform: either it's given directly (by platform argument) + # or it's given in the connection_settings argument + # If neither, create default Platform() + if platform is None: + self.platform = self.connection_settings.get('platform', Platform()) + else: + self.platform = platform + # Check if the user hasn't given two different platforms + if connection_settings and ('platform' in self.connection_settings) and ('platform' in connection_settings): + if connection_settings['platform'] is not platform: + raise TargetStableError('Platform specified in connection_settings ' + '({}) differs from that directly passed ' + '({})!)' + .format(connection_settings['platform'], + self.platform)) + self.connection_settings['platform'] = self.platform + self.working_directory = working_directory + self.executables_directory = executables_directory 
+ self.tmp_directory = tmp_directory + self.load_default_modules = load_default_modules + self.shell_prompt: Pattern[str] = bytes_regex(shell_prompt) + self.conn_cls = conn_cls + self.is_container = is_container + self.logger = get_logger(self.__class__.__name__) + self._installed_binaries: Dict[str, str] = {} + self._installed_modules: Dict[str, Module] = {} + self._cache: Dict = {} + self._shutils: Optional[str] = None + self._max_async = max_async + self.busybox: Optional[str] = None + + def normalize_mod_spec(spec) -> Tuple[str, Dict[str, Type[Module]]]: + if isinstance(spec, str): + return (spec, {}) + else: + [(name, params)] = spec.items() + return (name, params) + + normalized_modules: List[Tuple[str, Dict[str, Type[Module]]]] = sorted( + map( + normalize_mod_spec, + itertools.chain( + self.default_modules if load_default_modules else [], + modules or [], + self.platform.modules or [], + ) + ), + key=itemgetter(0), + ) + + # Ensure that we did not ask for the same module but different + # configurations. Empty configurations are ignored, so any + # user-provided conf will win against an empty conf. 
+ def elect(name: str, specs: List[Tuple[str, Dict[str, Type[Module]]]]) -> Tuple[str, Dict[str, Type[Module]]]: + specs = list(specs) + + confs = set( + tuple(sorted(params.items())) + for _, params in specs + if params + ) + if len(confs) > 1: + raise ValueError(f'Attempted to load the module "{name}" with multiple different configuration') + else: + if any( + params is None + for _, params in specs + ): + params = {} + else: + params = dict(confs.pop()) if confs else {} + + return (name, params) + + modules = dict(itertools.starmap( + elect, + itertools.groupby(normalized_modules, key=itemgetter(0)) + )) + + def get_kind(name: str) -> str: + return get_module(name).kind or '' + + def kind_conflict(kind: str, names: List[str]): + if kind: + raise ValueError(f'Cannot enable multiple modules sharing the same kind "{kind}": {sorted(names)}') + + list(itertools.starmap( + kind_conflict, + itertools.groupby( + sorted( + modules.keys(), + key=get_kind + ), + key=get_kind + ) + )) + self._modules = modules + + atexit.register( + WeakMethod(self.disconnect, atexit.unregister) + ) + + self._update_modules('early') + if connect: + self.connect(max_async=max_async) @property - def core_names(self): - return self.platform.core_names + def core_names(self) -> Union[List[caseless_string], List[str]]: + """ + A list of CPU core names in the order they appear + registered with the OS. If they are not specified, + they will be queried at run time. + + :return: CPU core names in order (e.g. ["A53", "A53", "A72", "A72"]). + """ + if self.platform: + return self.platform.core_names + raise ValueError("No Platform set for this target, cannot access core_names") @property - def core_clusters(self): - return self.platform.core_clusters + def core_clusters(self) -> List[int]: + """ + A list with cluster ids of each core (starting with + 0). If this is not specified, clusters will be + inferred from core names (cores with the same name are + assumed to be in a cluster). 
+ + :return: A list of integer cluster IDs for each core. + """ + if self.platform: + return self.platform.core_clusters + raise ValueError("No Platform set for this target cannot access core_clusters") @property - def big_core(self): - return self.platform.big_core + def big_core(self) -> Optional[str]: + """ + The name of the big core in a big.LITTLE system. If this is + not specified it will be inferred (on systems with exactly + two clusters). + + :return: Big core name, or None if not defined. + """ + if self.platform: + return self.platform.big_core + raise ValueError("No Platform set for this target cannot access big_core") @property - def little_core(self): - return self.platform.little_core + def little_core(self) -> Optional[str]: + """ + The name of the little core in a big.LITTLE system. If this is + not specified it will be inferred (on systems with exactly + two clusters). + + :return: Little core name, or None if not defined. + """ + if self.platform: + return self.platform.little_core + raise ValueError("No Platform set for this target cannot access little_core") @property - def is_connected(self): + def is_connected(self) -> bool: + """ + Indicates whether there is an active connection to the target. + + :return: True if connected, else False. + """ return self.conn is not None @property - def connected_as_root(self): - return self.conn and self.conn.connected_as_root + def connected_as_root(self) -> Optional[bool]: + """ + Indicates whether the connection user on the target is root (uid=0). + + :return: True if root, False otherwise, or None if unknown. + """ + if self.conn: + if self.conn.connected_as_root: + return True + return False @property - def is_rooted(self): + def is_rooted(self) -> Optional[bool]: + """ + Indicates whether superuser privileges (root or sudo) are available. + + :return: True if superuser privileges are accessible, False if not, + or None if undetermined. 
+ """ if self._is_rooted is None: try: self.execute('ls /', timeout=5, as_root=True) self._is_rooted = True - except(TargetError, TimeoutError): + except (TargetError, TimeoutError): self._is_rooted = False return self._is_rooted or self.connected_as_root @property @memoized - def needs_su(self): + def needs_su(self) -> Optional[bool]: + """ + Whether the current user must escalate privileges to run root commands. + + :return: True if the device is rooted but not connected as root. + """ return not self.connected_as_root and self.is_rooted @property @memoized - def kernel_version(self): - return KernelVersion(self.execute('{} uname -r -v'.format(quote(self.busybox))).strip()) + def kernel_version(self) -> 'KernelVersion': + """ + The kernel version from ``uname -r -v``, wrapped in a KernelVersion object. + + :raises ValueError: If busybox is unavailable for executing the uname command. + :return: Kernel version details. + """ + if self.busybox: + return KernelVersion(self.execute('{} uname -r -v'.format(quote(self.busybox))).strip()) + raise ValueError("busybox not set. Cannot get kernel version") @property - def hostid(self): + def hostid(self) -> int: + """ + A numeric ID representing the system's host identity. + + :return: The hostid as an integer (parsed from hex). + """ return int(self.execute('{} hostid'.format(self.busybox)).strip(), 16) @property - def hostname(self): + def hostname(self) -> str: + """ + System hostname from ``hostname`` or ``uname -n``. + + :return: Hostname of the target. + """ return self.execute('{} hostname'.format(self.busybox)).strip() @property - def os_version(self): # pylint: disable=no-self-use + def os_version(self) -> Dict[str, str]: # pylint: disable=no-self-use + """ + A mapping of OS version info. Empty by default; child classes may override. + + :return: OS version details. + """ return {} @property - def model(self): + def model(self) -> Optional[str]: + """ + Hardware model name, if any. 
+ + :return: Model name, or None if not defined. + """ return self.platform.model @property - def abi(self): # pylint: disable=no-self-use + def abi(self) -> Optional[str]: # pylint: disable=no-self-use + """ + The primary application binary interface (ABI) of this target. + + :return: ABI name (e.g. "armeabi-v7a"), or None if unknown. + """ return None @property - def supported_abi(self): + def supported_abi(self) -> List[Optional[str]]: + """ + A list of all supported ABIs. + + :return: List of ABI strings. + """ return [self.abi] @property @memoized - def cpuinfo(self): + def cpuinfo(self) -> 'Cpuinfo': + """ + Parsed data from ``/proc/cpuinfo``. + + :return: A :class:`Cpuinfo` instance with CPU details. + """ return Cpuinfo(self.execute('cat /proc/cpuinfo')) @property @memoized - def number_of_cpus(self): - num_cpus = 0 + def number_of_cpus(self) -> int: + """ + Count of CPU cores, determined by listing ``/sys/devices/system/cpu/cpu*``. + + :return: Number of CPU cores. + """ + num_cpus: int = 0 corere = re.compile(r'^\s*cpu\d+\s*$') - output = self.execute('ls /sys/devices/system/cpu', as_root=self.is_rooted) + output: str = self.execute('ls /sys/devices/system/cpu', as_root=self.is_rooted) for entry in output.split(): if corere.match(entry): num_cpus += 1 @@ -230,15 +529,23 @@ def number_of_cpus(self): @property @memoized - def number_of_nodes(self): - cmd = 'cd /sys/devices/system/node && {busybox} find . -maxdepth 1'.format(busybox=quote(self.busybox)) + def number_of_nodes(self) -> int: + """ + Number of NUMA nodes detected by enumerating ``/sys/devices/system/node``. + + :return: NUMA node count, or 1 if unavailable. + """ + if self.busybox: + cmd = 'cd /sys/devices/system/node && {busybox} find . -maxdepth 1'.format(busybox=quote(self.busybox)) + else: + raise ValueError('busybox not set. 
cannot form cmd') try: - output = self.execute(cmd, as_root=self.is_rooted) + output: str = self.execute(cmd, as_root=self.is_rooted) except TargetStableError: return 1 else: nodere = re.compile(r'^\./node\d+\s*$') - num_nodes = 0 + num_nodes: int = 0 for entry in output.splitlines(): if nodere.match(entry): num_nodes += 1 @@ -246,17 +553,29 @@ def number_of_nodes(self): @property @memoized - def list_nodes_cpus(self): - nodes_cpus = [] + def list_nodes_cpus(self) -> List[int]: + """ + Aggregated list of CPU IDs across all NUMA nodes. + + :return: A list of CPU IDs from each detected node. + """ + nodes_cpus: List[int] = [] for node in range(self.number_of_nodes): - path = self.path.join('/sys/devices/system/node/node{}/cpulist'.format(node)) - output = self.read_value(path) - nodes_cpus.append(ranges_to_list(output)) + if self.path: + path: str = self.path.join('/sys/devices/system/node/node{}/cpulist'.format(node)) + output: str = self.read_value(path) + if output: + nodes_cpus.extend(ranges_to_list(output)) return nodes_cpus @property @memoized - def config(self): + def config(self) -> 'KernelConfig': + """ + Parsed kernel config from ``/proc/config.gz`` or ``/boot/config-*``. + + :return: A :class:`KernelConfig` instance. + """ try: return KernelConfig(self.execute('zcat /proc/config.gz')) except TargetStableError: @@ -269,29 +588,61 @@ def config(self): @property @memoized - def user(self): + def user(self) -> str: + """ + The username for the active shell on the target. + + :return: Username (e.g., "root" or "shell"). + """ return self.getenv('USER') @property @memoized - def page_size_kb(self): + def page_size_kb(self) -> int: + """ + Page size in kilobytes, derived from ``/proc/self/smaps``. + + :return: Page size in KiB, or 0 if unknown. 
+ """ cmd = "cat /proc/self/smaps | {0} grep KernelPageSize | {0} head -n 1 | {0} awk '{{ print $2 }}'" return int(self.execute(cmd.format(self.busybox)) or 0) @property - def shutils(self): + def shutils(self) -> Optional[str]: + """ + Path to shell utilities (if installed by devlib). Internal usage. + + :return: The path or None if uninitialized. + """ if self._shutils is None: self._setup_scripts() return self._shutils - def is_running(self, comm): + def is_running(self, comm: str) -> bool: + """ + Check if a process with the specified name/command is running on the target. + + :param comm: The process name to search for. + :return: True if a matching process is found, else False. + """ cmd_ps = f'''{self.busybox} ps -A -T -o stat,comm''' cmd_awk = f'''{self.busybox} awk 'BEGIN{{found=0}} {{state=$1; $1=""; if ($state != "Z" && $0 == " {comm}") {{found=1}}}} END {{print found}}' ''' - result = self.execute(f"{cmd_ps} | {cmd_awk}") + result: str = self.execute(f"{cmd_ps} | {cmd_awk}") return bool(int(result)) @tls_property - def _conn(self): + def _conn(self) -> 'ConnectionBase': + """ + The underlying connection object. This will be ``None`` if an active + connection does not exist (e.g. if ``connect=False`` as passed on + initialization and :meth:`connect()` has not been called). + + :returns: The thread-local :class:`ConnectionBase` instance. + + .. note:: a :class:`~devlib.target.Target` will automatically create a + connection per thread. This will always be set to the connection + for the current thread. 
+ """ try: with self._lock: return self._unused_conns.pop() @@ -299,148 +650,30 @@ def _conn(self): return self.get_connection() # Add a basic property that does not require calling to get the value - conn = _conn.basic_property + conn: SupportedConnections = cast(SupportedConnections, _conn.basic_property) @tls_property - def _async_manager(self): + def _async_manager(self) -> 'AsyncManager': + """ + Thread-local property that holds an async manager for concurrency tasks. + + :return: Async manager instance for the current thread. + """ return asyn.AsyncManager() # Add a basic property that does not require calling to get the value - async_manager = _async_manager.basic_property + async_manager: 'AsyncManager' = cast('AsyncManager', _async_manager.basic_property) - def __init__(self, - connection_settings=None, - platform=None, - working_directory=None, - executables_directory=None, - connect=True, - modules=None, - load_default_modules=True, - shell_prompt=DEFAULT_SHELL_PROMPT, - conn_cls=None, - is_container=False, - max_async=50, - tmp_directory=None, - ): - - self._lock = threading.RLock() - self._async_pool = None - self._async_pool_size = None - self._unused_conns = set() - - self._is_rooted = None - self.connection_settings = connection_settings or {} - # Set self.platform: either it's given directly (by platform argument) - # or it's given in the connection_settings argument - # If neither, create default Platform() - if platform is None: - self.platform = self.connection_settings.get('platform', Platform()) - else: - self.platform = platform - # Check if the user hasn't given two different platforms - if 'platform' in self.connection_settings: - if connection_settings['platform'] is not platform: - raise TargetStableError('Platform specified in connection_settings ' - '({}) differs from that directly passed ' - '({})!)' - .format(connection_settings['platform'], - self.platform)) - self.connection_settings['platform'] = self.platform - 
self.working_directory = working_directory - self.executables_directory = executables_directory - self.tmp_directory = tmp_directory - self.load_default_modules = load_default_modules - self.shell_prompt = bytes_regex(shell_prompt) - self.conn_cls = conn_cls - self.is_container = is_container - self.logger = logging.getLogger(self.__class__.__name__) - self._installed_binaries = {} - self._installed_modules = {} - self._cache = {} - self._shutils = None - self._max_async = max_async - self.busybox = None - - def normalize_mod_spec(spec): - if isinstance(spec, str): - return (spec, {}) - else: - [(name, params)] = spec.items() - return (name, params) - - modules = sorted( - map( - normalize_mod_spec, - itertools.chain( - self.default_modules if load_default_modules else [], - modules or [], - self.platform.modules or [], - ) - ), - key=itemgetter(0), - ) - - # Ensure that we did not ask for the same module but different - # configurations. Empty configurations are ignored, so any - # user-provided conf will win against an empty conf. 
- def elect(name, specs): - specs = list(specs) - - confs = set( - tuple(sorted(params.items())) - for _, params in specs - if params - ) - if len(confs) > 1: - raise ValueError(f'Attempted to load the module "{name}" with multiple different configuration') - else: - if any( - params is None - for _, params in specs - ): - params = None - else: - params = dict(confs.pop()) if confs else {} - - return (name, params) - - modules = dict(itertools.starmap( - elect, - itertools.groupby(modules, key=itemgetter(0)) - )) - - def get_kind(name): - return get_module(name).kind or '' - - def kind_conflict(kind, names): - if kind: - raise ValueError(f'Cannot enable multiple modules sharing the same kind "{kind}": {sorted(names)}') - - list(itertools.starmap( - kind_conflict, - itertools.groupby( - sorted( - modules.keys(), - key=get_kind - ), - key=get_kind - ) - )) - self._modules = modules - - atexit.register( - WeakMethod(self.disconnect, atexit.unregister) - ) - - self._update_modules('early') - if connect: - self.connect(max_async=max_async) + def __getstate__(self) -> Dict[str, Any]: + """ + For pickling: exclude thread-local objects from the state. - def __getstate__(self): + :return: A dictionary representing the object's state. + """ # tls_property will recreate the underlying value automatically upon # access and is typically used for dynamic content that cannot be # pickled or should not transmitted to another thread. - ignored = { + ignored: set[str] = { k for k, v in inspect.getmembers(self.__class__) if isinstance(v, _BoundTLSProperty) @@ -456,7 +689,12 @@ def __getstate__(self): if k not in ignored } - def __setstate__(self, dct): + def __setstate__(self, dct: Dict[str, Any]) -> None: + """ + Restores the object's state after unpickling, reinitializing ephemeral objects. + + :param dct: The saved state dictionary. 
+ """ self.__dict__ = dct pool_size = self._async_pool_size if pool_size is None: @@ -469,7 +707,18 @@ def __setstate__(self, dct): # connection and initialization @asyn.asyncf - async def connect(self, timeout=None, check_boot_completed=True, max_async=None): + async def connect(self, timeout: Optional[int] = None, + check_boot_completed: Optional[bool] = True, + max_async: Optional[int] = None) -> None: + """ + Connect to the target (e.g., via SSH or another transport). + + :param timeout: Timeout (in seconds) for connecting. + :param check_boot_completed: If ``True``, verify the target has booted. + :param max_async: The number of parallel async connections to allow. + + :raises TargetError: If the device fails to connect within the specified time. + """ self.platform.init_target_connection(self) # Forcefully set the thread-local value for the connection, with the # timeout we want @@ -483,12 +732,12 @@ async def connect(self, timeout=None, check_boot_completed=True, max_async=None) self.executables_directory = self.path.join( self.working_directory, 'bin' - ) + ) if self.path else '' for path in (self.working_directory, self.executables_directory): self.makedirs(path) - self.busybox = self.install(os.path.join(PACKAGE_BIN_DIRECTORY, self.abi, 'busybox'), timeout=30) + self.busybox = self.install(os.path.join(PACKAGE_BIN_DIRECTORY, self.abi or '', 'busybox'), timeout=30) self.conn.busybox = self.busybox # If neither the mktemp call nor _resolve_paths() managed to get a @@ -501,7 +750,7 @@ async def connect(self, timeout=None, check_boot_completed=True, max_async=None) # Some Android platforms don't have a working mktemp unless # TMPDIR is set, so we let AndroidTarget._resolve_paths() deal # with finding a suitable location. 
- tmp = self.path.join(self.working_directory, 'tmp') + tmp = self.path.join(self.working_directory, 'tmp') if self.path else '' else: tmp = tmp.strip() self.tmp_directory = tmp @@ -511,10 +760,19 @@ async def connect(self, timeout=None, check_boot_completed=True, max_async=None) self.platform.update_from_target(self) self._update_modules('connected') - def _detect_max_async(self, max_async): + def _detect_max_async(self, max_async: int) -> None: + """ + Attempt to detect the maximum number of parallel asynchronous + commands the target can handle by opening multiple connections. + + :param max_async: Upper bound for parallel async connections. + """ self.logger.debug('Detecting max number of async commands ...') - def make_conn(_): + def make_conn(_) -> Optional[SupportedConnections]: + """ + create a connection to target to execute a command + """ try: conn = self.get_connection() except Exception: @@ -546,23 +804,27 @@ def make_conn(_): finally: logging.disable(logging.NOTSET) - conns = {conn for conn in conns if conn is not None} + resultconns = {conn for conn in conns if conn is not None} # Keep the connection so it can be reused by future threads - self._unused_conns.update(conns) - max_conns = len(conns) + self._unused_conns.update(resultconns) + max_conns = len(resultconns) self.logger.debug(f'Detected max number of async commands: {max_conns}') self._async_pool_size = max_conns self._async_pool = ThreadPoolExecutor(max_conns) @asyn.asyncf - async def check_connection(self): + async def check_connection(self) -> None: """ - Check that the connection works without obvious issues. + Perform a quick command to verify the target's shell is responsive. + + :raises TargetStableError: If the shell is present but not functioning + correctly (e.g., output on stderr). + :raises TargetNotRespondingError: If the target is unresponsive. 
""" - async def check(**kwargs): - out = await self.execute.asyn('true', **kwargs) + async def check(*, as_root: Union[Literal[False], Literal[True], str] = False) -> None: + out = await self.execute.asyn('true', as_root=as_root) if out: raise TargetStableError('The shell seems to not be functional and adds content to stderr: {!r}'.format(out)) @@ -573,9 +835,13 @@ async def check(**kwargs): if self.is_rooted: await check(as_root=True) - def disconnect(self): + def disconnect(self) -> None: + """ + Close all active connections to the target and terminate any + connection threads or asynchronous operations. + """ with self._lock: - thread_conns = self._conn.get_all_values() + thread_conns: Set[SupportedConnections] = self._conn.get_all_values() # Now that we have all the connection objects, we simply reset the # TLS property so that the connections we obtained will not be # reused anywhere. @@ -593,30 +859,85 @@ def disconnect(self): pool.__exit__(None, None, None) def __enter__(self): + """ + Context manager entrypoint. Returns self. + """ return self def __exit__(self, *args, **kwargs): + """ + Context manager exitpoint. Automatically disconnects from the device. + """ self.disconnect() async def __aenter__(self): + """ + Async context manager entry. + """ return self.__enter__() async def __aexit__(self, *args, **kwargs): + """ + Async context manager exit. + """ return self.__exit__(*args, **kwargs) - def get_connection(self, timeout=None): + def get_connection(self, timeout: Optional[int] = None) -> SupportedConnections: + """ + Get an additional connection to the target. A connection can be used to + execute one blocking command at time. This will return a connection that can + be used to interact with a target in parallel while a blocking operation is + being executed. + + This should *not* be used to establish an initial connection; use + :meth:`connect()` instead. + + :param timeout: Timeout (in seconds) for establishing the connection. 
+ :returns: A new connection object to be used by the caller. + :raises ValueError: If no connection class (`conn_cls`) is set. + + .. note:: :class:`~devlib.target.Target` will automatically create a connection + per thread, so you don't normally need to use this explicitly in + threaded code. This is generally useful if you want to perform a + blocking operation (e.g. using :class:`background()`) while at the same + time doing something else in the same host-side thread. + """ if self.conn_cls is None: raise ValueError('Connection class not specified on Target creation.') - conn = self.conn_cls(timeout=timeout, **self.connection_settings) # pylint: disable=not-callable + conn: SupportedConnections = self.conn_cls(timeout=timeout, **self.connection_settings) # pylint: disable=not-callable # This allows forwarding the detected busybox for connections created in new threads. conn.busybox = self.busybox return conn - def wait_boot_complete(self, timeout=10): + def wait_boot_complete(self, timeout: Optional[int] = 10) -> None: + """ + Wait for the device to boot. Must be overridden by derived classes + if the device needs a specific boot-completion check. + + :param timeout: How long to wait for the device to finish booting. + :raises NotImplementedError: If not implemented in child classes. + """ raise NotImplementedError() @asyn.asyncf - async def setup(self, executables=None): + async def setup(self, executables: Optional[List[str]] = None) -> None: + """ + This will perform an initial one-time set up of a device for devlib + interaction. This involves deployment of tools relied on the + :class:`~devlib.target.Target`, creation of working locations on the device, + etc. + + Usually, it is enough to call this method once per new device, as its effects + will persist across reboots. However, it is safe to call this method multiple + times. 
It may therefore be a good practice to always call it once at the + beginning of a script to ensure that subsequent interactions will succeed. + + Optionally, this may also be used to deploy additional tools to the device + by specifying a list of binaries to install in the ``executables`` parameter. + + :param executables: Optional list of host-side binaries to install + on the target during setup. + """ await self._setup_scripts.asyn() for host_exe in (executables or []): # pylint: disable=superfluous-parens @@ -628,11 +949,23 @@ async def setup(self, executables=None): # Initialize modules which requires Busybox (e.g. shutil dependent tasks) self._update_modules('setup') - def reboot(self, hard=False, connect=True, timeout=180): + def reboot(self, hard: bool = False, connect: bool = True, timeout: int = 180) -> None: + """ + Reboot the target. Optionally performs a hard reset if supported + by a :class:`HardRestModule`. + + :param hard: If ``True``, use a hard reset. + :param connect: If ``True``, reconnect after reboot finishes. + :param timeout: Timeout in seconds for reconnection. + + :raises TargetStableError: If hard reset is requested but not supported. + :raises TargetTransientError: If the target is not currently connected + and a soft reset is requested. + """ if hard: if not self.has('hard_reset'): raise TargetStableError('Hard reset not supported for this target.') - self.hard_reset() # pylint: disable=no-member + cast(HardRestModule, self.hard_reset)() # pylint: disable=no-member else: if not self.is_connected: message = 'Cannot reboot target because it is disconnected. 
' +\ @@ -648,7 +981,7 @@ def reboot(self, hard=False, connect=True, timeout=180): time.sleep(reset_delay) timeout = max(timeout - reset_delay, 10) if self.has('boot'): - self.boot() # pylint: disable=no-member + cast(BootModule, self.boot)() # pylint: disable=no-member self.conn.connected_as_root = None if connect: self.connect(timeout=timeout) @@ -656,7 +989,7 @@ def reboot(self, hard=False, connect=True, timeout=180): # file transfer @asyn.asynccontextmanager - async def _xfer_cache_path(self, name): + async def _xfer_cache_path(self, name: str) -> AsyncGenerator[str, None]: """ Context manager to provide a unique path in the transfer cache with the basename of the given name. @@ -665,10 +998,12 @@ async def _xfer_cache_path(self, name): name = os.path.normpath(name) name = os.path.basename(name) async with self.make_temp() as tmp: - yield self.path.join(tmp, name) + if self.path: + yield self.path.join(tmp, name) @asyn.asyncf - async def _prepare_xfer(self, action, sources, dest, pattern=None, as_root=False): + async def _prepare_xfer(self, action: str, sources: List[str], dest: str, + pattern: Optional[str] = None, as_root: bool = False) -> Dict[Tuple[str, ...], str]: """ Check the sanity of sources and destination and prepare the ground for transfering multiple sources. 
@@ -688,16 +1023,17 @@ async def wrapper(path): return wrapper - _target_cache = {} - async def target_paths_kind(paths, as_root=False): - def process(x): + _target_cache: Dict[str, Optional[str]] = {} + + async def target_paths_kind(paths: List[str], as_root: bool = False) -> List[Optional[str]]: + def process(x: str) -> Optional[str]: x = x.strip() if x == 'notexist': return None else: return x - _paths = [ + _paths: List[str] = [ path for path in paths if path not in _target_cache @@ -717,9 +1053,10 @@ def process(x): for path in paths ] - _host_cache = {} - async def host_paths_kind(paths, as_root=False): - def path_kind(path): + _host_cache: Dict[str, Optional[str]] = {} + + async def host_paths_kind(paths: List[str], as_root: bool = False) -> List[Optional[str]]: + def path_kind(path: str) -> Optional[str]: if os.path.isdir(path): return 'dir' elif os.path.exists(path): @@ -742,39 +1079,40 @@ def path_kind(path): # use SFTP for these operations, which should be cheaper than # Target.execute() if action == 'push': - src_excep = HostError + src_excep: Type[DevlibError] = HostError src_path_kind = host_paths_kind _dst_mkdir = once(self.makedirs.asyn) - dst_path_join = self.path.join + if self.path: + dst_path_join = self.path.join dst_paths_kind = target_paths_kind @once - async def dst_remove_file(path): + async def dst_remove_file(path: str): # type:ignore return await self.remove.asyn(path, as_root=as_root) elif action == 'pull': src_excep = TargetStableError src_path_kind = target_paths_kind @once - async def _dst_mkdir(path): + async def _dst_mkdir(path: str): return os.makedirs(path, exist_ok=True) dst_path_join = os.path.join dst_paths_kind = host_paths_kind @once - async def dst_remove_file(path): + async def dst_remove_file(path: str): return os.remove(path) else: raise ValueError('Unknown action "{}"'.format(action)) # Handle the case where path is None - async def dst_mkdir(path): + async def dst_mkdir(path: Optional[str]) -> None: if path: await 
_dst_mkdir(path) - async def rewrite_dst(src, dst): - new_dst = dst_path_join(dst, os.path.basename(src)) + async def rewrite_dst(src: str, dst: str) -> str: + new_dst: str = dst_path_join(dst, os.path.basename(src)) src_kind, = await src_path_kind([src], as_root) # Batch both checks to avoid a costly extra execute() @@ -832,14 +1170,38 @@ async def f(src): @asyn.asyncf @call_conn - async def push(self, source, dest, as_root=False, timeout=None, globbing=False): # pylint: disable=arguments-differ + async def push(self, source: str, dest: str, as_root: bool = False, + timeout: Optional[int] = None, globbing: bool = False) -> None: # pylint: disable=arguments-differ + """ + Transfer a file from the host machine to the target device. + + If transfer polling is supported (ADB connections and SSH connections), + ``poll_transfers`` is set in the connection, and a timeout is not specified, + the push will be polled for activity. Inactive transfers will be + cancelled. (See :ref:`connection-types` for more information on polling). + + :param source: path on the host + :param dest: path on the target + :param as_root: whether root is required. Defaults to false. + :param timeout: timeout (in seconds) for the transfer; if the transfer does + not complete within this period, an exception will be raised. Leave unset + to utilise transfer polling if enabled. + :param globbing: If ``True``, the ``source`` is interpreted as a globbing + pattern instead of being take as-is. If the pattern has multiple + matches, ``dest`` must be a folder (or will be created as such if it + does not exists yet). + + :raises TargetStableError: If any failure occurs in copying + (e.g., insufficient permissions). 
+ + """ source = str(source) dest = str(dest) - sources = glob.glob(source) if globbing else [source] - mapping = await self._prepare_xfer.asyn('push', sources, dest, pattern=source if globbing else None, as_root=as_root) + sources: List[str] = glob.glob(source) if globbing else [source] + mapping: Dict[List[str], str] = await self._prepare_xfer.asyn('push', sources, dest, pattern=source if globbing else None, as_root=as_root) - def do_push(sources, dest): + def do_push(sources: List[str], dest: str) -> None: for src in sources: self.async_manager.track_access( asyn.PathAccess(namespace='host', path=src, mode='r') @@ -847,7 +1209,7 @@ def do_push(sources, dest): self.async_manager.track_access( asyn.PathAccess(namespace='target', path=dest, mode='w') ) - return self.conn.push(sources, dest, timeout=timeout) + self.conn.push(sources, dest, timeout=timeout) if as_root: for sources, dest in mapping.items(): @@ -857,11 +1219,11 @@ async def f(source): await self.execute.asyn("mv -f -- {} {}".format(quote(device_tempfile), quote(dest)), as_root=True) await self.async_manager.map_concurrently(f, sources) else: - for sources, dest in mapping.items(): - do_push(sources, dest) + for sources_map, dest_map in mapping.items(): + do_push(sources_map, dest_map) @asyn.asyncf - async def _expand_glob(self, pattern, **kwargs): + async def _expand_glob(self, pattern: str, **kwargs: Dict[str, bool]) -> Optional[List[str]]: """ Expand the given path globbing pattern on the target using the shell globbing. 
@@ -888,27 +1250,57 @@ async def _expand_glob(self, pattern, **kwargs): pattern = pattern.replace(c, '\\' + c) cmd = "exec printf '%s\n' {}".format(pattern) - # Make sure to use the same shell everywhere for the path globbing, - # ensuring consistent results no matter what is the default platform - # shell - cmd = '{} sh -c {} 2>/dev/null'.format(quote(self.busybox), quote(cmd)) - # On some shells, match failure will make the command "return" a - # non-zero code, even though the command was not actually called - result = await self.execute.asyn(cmd, strip_colors=False, check_exit_code=False, **kwargs) - paths = result.splitlines() - if not paths: - raise TargetStableError('No file matching: {}'.format(pattern)) - - return paths + if self.busybox: + # Make sure to use the same shell everywhere for the path globbing, + # ensuring consistent results no matter what is the default platform + # shell + cmd = '{} sh -c {} 2>/dev/null'.format(quote(self.busybox), quote(cmd)) + # On some shells, match failure will make the command "return" a + # non-zero code, even though the command was not actually called + result: str = await self.execute.asyn(cmd, strip_colors=False, check_exit_code=False, **kwargs) + paths: List[str] = result.splitlines() + if not paths: + raise TargetStableError('No file matching: {}'.format(pattern)) + + return paths + return None @asyn.asyncf @call_conn - async def pull(self, source, dest, as_root=False, timeout=None, globbing=False, via_temp=False): # pylint: disable=arguments-differ + async def pull(self, source: str, dest: str, as_root: bool = False, + timeout: Optional[int] = None, globbing: bool = False, + via_temp: bool = False) -> None: # pylint: disable=arguments-differ + """ + Transfer a file from the target device to the host machine. + + If transfer polling is supported (ADB connections and SSH connections), + ``poll_transfers`` is set in the connection, and a timeout is not specified, + the pull will be polled for activity. 
Inactive transfers will be + cancelled. (See :ref:`connection-types` for more information on polling). + + :param source: path on the target + :param dest: path on the host + :param as_root: whether root is required. Defaults to false. + :param timeout: timeout (in seconds) for the transfer; if the transfer does + not complete within this period, an exception will be raised. + :param globbing: If ``True``, the ``source`` is interpreted as a globbing + pattern instead of being take as-is. If the pattern has multiple + matches, ``dest`` must be a folder (or will be created as such if it + does not exists yet). + :param via_temp: If ``True``, copy the file first to a temporary location on + the target, and then pull it. This can avoid issues some filesystems, + notably paramiko + OpenSSH combination having performance issues when + pulling big files from sysfs. + + :raises TargetStableError: If a transfer error occurs. + """ source = str(source) dest = str(dest) if globbing: - sources = await self._expand_glob.asyn(source, as_root=as_root) + sources: Optional[List[str]] = await self._expand_glob.asyn(source, as_root=as_root) + if sources is None: + sources = [source] else: sources = [source] @@ -916,9 +1308,9 @@ async def pull(self, source, dest, as_root=False, timeout=None, globbing=False, # so use a temporary copy instead. 
via_temp |= as_root - mapping = await self._prepare_xfer.asyn('pull', sources, dest, pattern=source if globbing else None, as_root=as_root) + mapping: Dict[List[str], str] = await self._prepare_xfer.asyn('pull', sources, dest, pattern=source if globbing else None, as_root=as_root) - def do_pull(sources, dest): + def do_pull(sources: List[str], dest: str) -> None: for src in sources: self.async_manager.track_access( asyn.PathAccess(namespace='target', path=src, mode='r') @@ -932,45 +1324,49 @@ def do_pull(sources, dest): for sources, dest in mapping.items(): async def f(source): async with self._xfer_cache_path(source) as device_tempfile: - cp_cmd = f"{quote(self.busybox)} cp -rL -- {quote(source)} {quote(device_tempfile)}" - chmod_cmd = f"{quote(self.busybox)} chmod 0644 -- {quote(device_tempfile)}" + cp_cmd = f"{quote(self.busybox or '')} cp -rL -- {quote(source)} {quote(device_tempfile)}" + chmod_cmd = f"{quote(self.busybox or '')} chmod 0644 -- {quote(device_tempfile)}" await self.execute.asyn(f"{cp_cmd} && {chmod_cmd}", as_root=as_root) do_pull([device_tempfile], dest) await self.async_manager.map_concurrently(f, sources) else: - for sources, dest in mapping.items(): - do_pull(sources, dest) + for sources_map, dest_map in mapping.items(): + do_pull(sources_map, dest_map) @asyn.asyncf - async def get_directory(self, source_dir, dest, as_root=False): + async def get_directory(self, source_dir: str, dest: str, + as_root: bool = False) -> None: """ Pull a directory from the device, after compressing dir """ - # Create all file names - tar_file_name = source_dir.lstrip(self.path.sep).replace(self.path.sep, '.') - # Host location of dir - outdir = os.path.join(dest, tar_file_name) - # Host location of archive - tar_file_name = '{}.tar'.format(tar_file_name) - tmpfile = os.path.join(dest, tar_file_name) + if self.path: + # Create all file names + tar_file_name: str = source_dir.lstrip(self.path.sep).replace(self.path.sep, '.') + # Host location of dir + outdir: str = 
os.path.join(dest, tar_file_name) + # Host location of archive + tar_file_name = '{}.tar'.format(tar_file_name) + tmpfile: str = os.path.join(dest, tar_file_name) # If root is required, use tmp location for tar creation. - tar_file_cm = self._xfer_cache_path if as_root else nullcontext + tar_file_cm: Union[Callable[[str], AsyncContextManager[str]], Callable[[str], nullcontext]] = self._xfer_cache_path if as_root else nullcontext # Does the folder exist? await self.execute.asyn('ls -la {}'.format(quote(source_dir)), as_root=as_root) - async with tar_file_cm(tar_file_name) as tar_file_name: + async with tar_file_cm(tar_file_name) as tar_file: # Try compressing the folder try: - await self.execute.asyn('{} tar -cvf {} {}'.format( - quote(self.busybox), quote(tar_file_name), quote(source_dir) - ), as_root=as_root) + # FIXME - should we raise an error in the else case here when busybox or tar_file is None + if self.busybox and tar_file: + await self.execute.asyn('{} tar -cvf {} {}'.format( + quote(self.busybox), quote(tar_file), quote(source_dir) + ), as_root=as_root) except TargetStableError: - self.logger.debug('Failed to run tar command on target! ' \ - 'Not pulling directory {}'.format(source_dir)) + self.logger.debug('Failed to run tar command on target! ' + 'Not pulling directory {}'.format(source_dir)) # Pull the file if not os.path.exists(dest): os.mkdir(dest) - await self.pull.asyn(tar_file_name, tmpfile) + await self.pull.asyn(tar_file, tmpfile) # Decompress with tarfile.open(tmpfile, 'r') as f: safe_extract(f, outdir) @@ -978,18 +1374,26 @@ async def get_directory(self, source_dir, dest, as_root=False): # execution - def _prepare_cmd(self, command, force_locale): + def _prepare_cmd(self, command: SubprocessCommand, force_locale: str) -> SubprocessCommand: + """ + Internal helper to prepend environment settings (e.g., PATH, locale) + to a command string before execution. + + :param command: The command to execute. 
+ :param force_locale: The locale to enforce (e.g. 'C') or None for none. + :return: The updated command string with environment preparation. + """ tmpdir = f'TMPDIR={quote(self.tmp_directory)}' if self.tmp_directory else '' # Force the locale if necessary for more predictable output if force_locale: # Use an explicit export so that the command is allowed to be any # shell statement, rather than just a command invocation - command = f'export LC_ALL={quote(force_locale)} {tmpdir} && {command}' + command = f'export LC_ALL={quote(force_locale)} {tmpdir} && {cast(str, command)}' # Ensure to use deployed command when availables if self.executables_directory: - command = f"export PATH={quote(self.executables_directory)}:$PATH && {command}" + command = f"export PATH={quote(self.executables_directory)}:$PATH && {cast(str, command)}" return command @@ -998,18 +1402,30 @@ class _BrokenConnection(Exception): @asyn.asyncf @call_conn - async def _execute_async(self, *args, **kwargs): + async def _execute_async(self, *args: Any, **kwargs: Any) -> str: + """ + Internal asynchronous handler for command execution. + + This is typically invoked by the asynchronous version of :meth:`execute`. + It may create a background thread or use an existing thread pool + to run the blocking command. + + :param args: Positional arguments forwarded to the blocking command. + :param kwargs: Keyword arguments forwarded to the blocking command. + :return: The stdout of the command executed. + :raises DevlibError: If any error occurs during command execution. + """ execute = functools.partial( self._execute, *args, **kwargs ) - pool = self._async_pool + pool: Optional[ThreadPoolExecutor] = self._async_pool if pool is None: return execute() else: - def thread_f(): + def thread_f() -> str: # If we cannot successfully connect from the thread, it might # mean that something external opened a connection on the # target, so we just revert to the blocking path. 
@call_conn
def _execute(self, command: 'SubprocessCommand', timeout: Optional[int] = None, check_exit_code: bool = True,
             as_root: bool = False, strip_colors: bool = True, will_succeed: bool = False,
             force_locale: Optional[str] = 'C') -> str:
    """
    Blocking command executor backing :meth:`execute`.

    :param command: The command to be executed.
    :param timeout: Timeout (in seconds) for the execution of the command. If
        specified, an exception will be raised if execution does not complete
        within the specified period.
    :param check_exit_code: If ``True`` (the default) the exit code (on target)
        from execution of the command will be checked, and an exception will be
        raised if it is not ``0``.
    :param as_root: The command will be executed as root. This will fail on
        unrooted targets.
    :param strip_colors: The command output will have colour encodings and
        most ANSI escape sequences stripped out before returning.
    :param will_succeed: The command is assumed to always succeed, unless there is
        an issue in the environment like the loss of network connectivity. That
        will make the method always raise an instance of a subclass of
        :class:`DevlibTransientError` when the command fails, instead of a
        :class:`DevlibStableError`.
    :param force_locale: Prepend ``LC_ALL=<force_locale>`` in front of the
        command to get predictable output that can be more safely parsed.
        If ``None``, no locale is prepended.
    :return: The stdout produced by the command.
    """
    command = self._prepare_cmd(command, force_locale)
    return self.conn.execute(command, timeout=timeout,
                             check_exit_code=check_exit_code, as_root=as_root,
                             strip_colors=strip_colors, will_succeed=will_succeed)
+ """ command = self._prepare_cmd(command, force_locale) return self.conn.execute(command, timeout=timeout, - check_exit_code=check_exit_code, as_root=as_root, - strip_colors=strip_colors, will_succeed=will_succeed) + check_exit_code=check_exit_code, as_root=as_root, + strip_colors=strip_colors, will_succeed=will_succeed) execute = asyn._AsyncPolymorphicFunction( asyn=_execute_async.asyn, @@ -1042,38 +1482,62 @@ def _execute(self, command, timeout=None, check_exit_code=True, ) @call_conn - def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False, - force_locale='C', timeout=None): - conn = self.conn + def background(self, command: SubprocessCommand, stdout: int = subprocess.PIPE, + stderr: int = subprocess.PIPE, as_root: Optional[bool] = False, + force_locale: str = 'C', timeout: Optional[int] = None) -> BackgroundCommand: + """ + Execute the command on the target, invoking it via subprocess on the host. + This will return :class:`subprocess.Popen` instance for the command. + + :param command: The command to be executed. + :param stdout: By default, standard output will be piped from the subprocess; + this may be used to redirect it to an alternative file handle. + :param stderr: By default, standard error will be piped from the subprocess; + this may be used to redirect it to an alternative file handle. + :param as_root: The command will be executed as root. This will fail on + unrooted targets. + :param force_locale: Prepend ``LC_ALL=`` in front of the + command to get predictable output that can be more safely parsed. + If ``None``, no locale is prepended. + :param timeout: Timeout (in seconds) for the execution of the command. When + the timeout expires, :meth:`BackgroundCommand.cancel` is executed to + terminate the command. + + :return: A handle to the background command. + + .. note:: This **will block the connection** until the command completes. 
+ """ command = self._prepare_cmd(command, force_locale) - bg_cmd = self.conn.background(command, stdout, stderr, as_root) + bg_cmd: BackgroundCommand = self.conn.background(command, stdout, stderr, as_root) if timeout is not None: timer = threading.Timer(timeout, function=bg_cmd.cancel) timer.daemon = True timer.start() return bg_cmd - def invoke(self, binary, args=None, in_directory=None, on_cpus=None, - redirect_stderr=False, as_root=False, timeout=30): + def invoke(self, binary: str, args: Optional[Union[str, Iterable[str]]] = None, in_directory: Optional[str] = None, + on_cpus: Optional[Union[int, List[int], str]] = None, redirect_stderr: bool = False, as_root: bool = False, + timeout: Optional[int] = 30) -> str: """ Executes the specified binary under the specified conditions. - :binary: binary to execute. Must be present and executable on the device. - :args: arguments to be passed to the binary. The can be either a list or + :param binary: binary to execute. Must be present and executable on the device. + :param args: arguments to be passed to the binary. The can be either a list or a string. - :in_directory: execute the binary in the specified directory. This must + :param in_directory: execute the binary in the specified directory. This must be an absolute path. - :on_cpus: taskset the binary to these CPUs. This may be a single ``int`` (in which + :param on_cpus: taskset the binary to these CPUs. This may be a single ``int`` (in which case, it will be interpreted as the mask), a list of ``ints``, in which case this will be interpreted as the list of cpus, or string, which will be interpreted as a comma-separated list of cpu ranges, e.g. ``"0,4-7"``. 
- :as_root: Specify whether the command should be run as root - :timeout: If the invocation does not terminate within this number of seconds, + :param redirect_stderr: redirect stderr to stdout + :param as_root: Specify whether the command should be run as root + :param timeout: If the invocation does not terminate within this number of seconds, a ``TimeoutError`` exception will be raised. Set to ``None`` if the invocation should not timeout. - :returns: output of command. + :return: The captured output of the command. """ command = binary if args: @@ -1081,33 +1545,36 @@ def invoke(self, binary, args=None, in_directory=None, on_cpus=None, args = ' '.join(args) command = '{} {}'.format(command, args) if on_cpus: - on_cpus = bitmask(on_cpus) - command = '{} taskset 0x{:x} {}'.format(quote(self.busybox), on_cpus, command) + on_cpus_bitmask = bitmask(on_cpus) + if self.busybox: + command = '{} taskset 0x{:x} {}'.format(quote(self.busybox), on_cpus_bitmask, command) if in_directory: command = 'cd {} && {}'.format(quote(in_directory), command) if redirect_stderr: command = '{} 2>&1'.format(command) return self.execute(command, as_root=as_root, timeout=timeout) - def background_invoke(self, binary, args=None, in_directory=None, - on_cpus=None, as_root=False): + def background_invoke(self, binary: str, args: Optional[Union[str, Iterable[str]]] = None, in_directory: Optional[str] = None, + on_cpus: Optional[Union[int, List[int], str]] = None, as_root: bool = False) -> BackgroundCommand: """ - Executes the specified binary as a background task under the - specified conditions. + Runs the specified binary as a background task, possibly pinned to CPUs or + launched in a certain directory. - :binary: binary to execute. Must be present and executable on the device. - :args: arguments to be passed to the binary. The can be either a list or + :param binary: binary to execute. Must be present and executable on the device. + :param args: arguments to be passed to the binary. 
@asyn.asyncf
async def kick_off(self, command: str, as_root: Optional[bool] = None) -> None:
    """
    Launch ``command`` on the target and return immediately, without keeping
    any handle to it. Unlike :meth:`background`, the connection is not
    blocked; the flip side is that there is no way to learn when the command
    finishes (other than polling ``ps()``) or to collect its output (unless
    the command itself redirects it into a file that can be pulled later).

    :param command: The shell statement to run.
    :param as_root: Run the command as root (fails on unrooted targets).
    :raises TargetStableError: If the working directory or busybox path has
        not been initialised yet.
    """
    if not (self.working_directory and self.busybox):
        raise TargetStableError("working directory or busybox not set. cannot kick off command")
    wrapped = (
        f'cd {quote(self.working_directory)} && '
        f'{quote(self.busybox)} sh -c {quote(command)} >/dev/null 2>&1'
    )
    self.background(wrapped, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, as_root=as_root)
# sysfs interaction

R = TypeVar('R')

@asyn.asyncf
async def read_value(self, path: str, kind: Optional[Callable[[str], 'R']] = None) -> Union[str, 'R']:
    """
    Read the (stripped) contents of a file on the target. This is primarily
    intended for sysfs/procfs/debugfs entries.

    :param path: File to read.
    :param kind: Optional converter applied to the raw string (any callable
        taking exactly one string parameter, e.g. ``int``).
    :return: The file contents, converted via ``kind`` when provided.
    :raises ValueError: If ``kind`` fails to convert the output.
    :raises TargetStableError: If the file does not exist or is unreadable.
    """
    self.async_manager.track_access(
        asyn.PathAccess(namespace='target', path=path, mode='r')
    )
    output: str = await self.execute.asyn('cat {}'.format(quote(path)), as_root=self.needs_su)  # pylint: disable=E1103
    output = output.strip()
    if kind is not None and callable(kind):
        # Apply the converter even to an empty string: silently returning a
        # raw str when the caller asked for e.g. an int would hand back the
        # wrong type and hide the real failure.
        try:
            return kind(output)
        except Exception as e:
            raise ValueError(f"Error converting output using {kind}: {e}") from e
    return output
+ """ self.async_manager.track_access( asyn.PathAccess(namespace='target', path=path, mode='r') ) - output = await self.execute.asyn('cat {}'.format(quote(path)), as_root=self.needs_su) # pylint: disable=E1103 + output: str = await self.execute.asyn('cat {}'.format(quote(path)), as_root=self.needs_su) # pylint: disable=E1103 output = output.strip() - if kind: - return kind(output) + if kind and callable(kind) and output: + try: + return kind(output) + except Exception as e: + raise ValueError(f"Error converting output using {kind}: {e}") else: return output @asyn.asyncf - async def read_int(self, path): + async def read_int(self, path: str) -> int: + """ + Equivalent to ``Target.read_value(path, kind=devlib.utils.types.integer)`` + + :param path: The file path to read. + :return: The integer value contained in the file. + :raises ValueError: If the file contents cannot be parsed as an integer. + """ return await self.read_value.asyn(path, kind=integer) @asyn.asyncf - async def read_bool(self, path): + async def read_bool(self, path: str) -> bool: + """ + Equivalent to ``Target.read_value(path, kind=devlib.utils.types.boolean)`` + + :param path: File path to read. + :return: True or False, parsed from the file content. + :raises ValueError: If the file contents cannot be interpreted as a boolean. + """ return await self.read_value.asyn(path, kind=boolean) @asyn.asynccontextmanager - async def revertable_write_value(self, path, value, verify=True, as_root=True): - orig_value = self.read_value(path) + async def revertable_write_value(self, path: str, value: Any, verify: bool = True, as_root: bool = True) -> AsyncGenerator: + """ + Same as :meth:`Target.write_value`, but as a context manager that will write + back the previous value on exit. + + :param path: The file path to write to on the target. + :param value: The value to write, converted to a string. + :param verify: If True, read the file back to confirm the change. + :param as_root: If True, write as root. 
@asyn.asynccontextmanager
async def revertable_write_value(self, path: str, value: Any, verify: bool = True, as_root: bool = True) -> AsyncGenerator:
    """
    Context-manager flavour of :meth:`write_value`: writes ``value`` to
    ``path`` on entry and restores the previous contents on exit, even when
    the body raises.

    :param path: Target file to write to.
    :param value: Value to write (stringified).
    :param verify: Read the file back to confirm each write.
    :param as_root: Perform the writes as root.
    :yield: Control, while the new value is in place.
    """
    # NOTE(review): the original value is read through the blocking call
    # path from async code — presumably relying on the asyn framework's
    # sync bridge; confirm this is intentional.
    saved: str = self.read_value(path)
    try:
        await self.write_value.asyn(path, value, verify=verify, as_root=as_root)
        yield
    finally:
        # Always restore, whether or not the body raised.
        await self.write_value.asyn(path, saved, verify=verify, as_root=as_root)

def batch_revertable_write_value(self, kwargs_list: List[Dict[str, Any]]) -> '_GeneratorContextManager':
    """
    Apply :meth:`revertable_write_value` for every kwargs dict in
    ``kwargs_list`` as a single context manager: all writes happen on entry,
    and every already-applied write is reverted on exit — including when a
    later write in the batch fails.

    :param kwargs_list: Dicts of keyword arguments for
        :meth:`revertable_write_value`, e.g. ``{"path": ..., "value": ...}``.
    :return: A context manager wrapping all the writes.
    """
    return batch_contextmanager(self.revertable_write_value, kwargs_list)
+ """ self.async_manager.track_access( asyn.PathAccess(namespace='target', path=path, mode='w') ) - value = str(value) + string_value = str(value) if verify: # Check in a loop for a while since updates to sysfs files can take @@ -1184,10 +1729,10 @@ async def write_value(self, path, value, verify=True, as_root=True): # request, such as hotplugging a CPU. cmd = ''' orig=$(cat {path} 2>/dev/null || printf "") -printf "%s" {value} > {path} || exit 10 -if [ {value} != "$orig" ]; then +printf "%s" {string_value} > {path} || exit 10 +if [ {string_value} != "$orig" ]; then trials=0 - while [ "$(cat {path} 2>/dev/null)" != {value} ]; do + while [ "$(cat {path} 2>/dev/null)" != {string_value} ]; do if [ $trials -ge 10 ]; then cat {path} exit 11 @@ -1198,56 +1743,59 @@ async def write_value(self, path, value, verify=True, as_root=True): fi ''' else: - cmd = '{busybox} printf "%s" {value} > {path}' - cmd = cmd.format(busybox=quote(self.busybox), path=quote(path), value=quote(value)) + cmd = '{busybox} printf "%s" {string_value} > {path}' + if self.busybox: + cmd = cmd.format(busybox=quote(self.busybox), path=quote(path), string_value=quote(string_value)) try: await self.execute.asyn(cmd, check_exit_code=True, as_root=as_root) except TargetCalledProcessError as e: if e.returncode == 10: - raise TargetStableError('Could not write "{value}" to {path}: {e.output}'.format( - value=value, path=path, e=e)) + raise TargetStableError('Could not write "{string_value}" to {path}: {e.output}'.format( + string_value=string_value, path=path, e=e)) elif verify and e.returncode == 11: out = e.output - message = 'Could not set the value of {} to "{}" (read "{}")'.format(path, value, out) + message = 'Could not set the value of {} to "{}" (read "{}")'.format(path, string_value, out) raise TargetStableError(message) else: raise @asyn.asynccontextmanager - async def make_temp(self, is_directory=True, directory=None, prefix=None): + async def make_temp(self, is_directory: Optional[bool] = True, 
directory: Optional[str] = None, + prefix: Optional[str] = None) -> AsyncGenerator: """ Creates temporary file/folder on target and deletes it once it's done. :param is_directory: Specifies if temporary object is a directory, defaults to True. - :type is_directory: bool or None :param directory: Temp object will be created under this directory, defaults to ``Target.working_directory``. - :type directory: str or None :param prefix: Prefix of temp object's name. - :type prefix: str or None :yield: Full path of temp object. - :rtype: str """ directory = directory or self.tmp_directory prefix = f'{prefix}-' if prefix else '' temp_obj = None try: - cmd = f'mktemp -p {quote(directory)} {quote(prefix)}XXXXXX' - if is_directory: - cmd += ' -d' + if directory is not None: + cmd = f'mktemp -p {quote(directory)} {quote(prefix)}XXXXXX' + if is_directory: + cmd += ' -d' - temp_obj = (await self.execute.asyn(cmd)).strip() - yield temp_obj + temp_obj = (await self.execute.asyn(cmd)).strip() + yield temp_obj finally: if temp_obj is not None: await self.remove.asyn(temp_obj) - def reset(self): + def reset(self) -> None: + """ + Soft reset the target. Typically, this means executing ``reboot`` on the + target. + """ try: self.execute('reboot', as_root=self.needs_su, timeout=2) except (TargetError, subprocess.CalledProcessError): @@ -1256,7 +1804,11 @@ def reset(self): self.conn.connected_as_root = None @call_conn - def check_responsive(self, explode=True): + def check_responsive(self, explode: bool = True) -> bool: + """ + Returns ``True`` if the target appears to be responsive and ``False`` + otherwise. + """ try: self.conn.execute('ls /', timeout=5) return True @@ -1267,51 +1819,96 @@ def check_responsive(self, explode=True): # process management - def kill(self, pid, signal=None, as_root=False): + def kill(self, pid: int, signal: Optional[signal.Signals] = None, as_root: Optional[bool] = False) -> None: + """ + Send a signal (default SIGTERM) to a process by PID. 
def kill(self, pid: int, signal: Optional[signal.Signals] = None, as_root: Optional[bool] = False) -> None:
    """
    Send a signal (SIGTERM unless specified) to a single process.

    :param pid: PID of the process to signal.
    :param signal: Signal to deliver (e.g. ``signal.SIGKILL``); ``None``
        lets ``kill`` apply its default (SIGTERM).
    :param as_root: Run the ``kill`` command as root.
    """
    sig_arg = '-s {}'.format(signal) if signal else ''
    self.execute('{} kill {} {}'.format(self.busybox, sig_arg, pid), as_root=as_root)

def killall(self, process_name: str, signal: Optional[signal.Signals] = None,
            as_root: Optional[bool] = False) -> None:
    """
    Send a signal to every running instance of the named process, ignoring
    processes that cannot be signalled.

    :param process_name: Name of the processes to signal.
    :param signal: Signal to deliver.
    :param as_root: Run each ``kill`` as root.
    """
    for target_pid in self.get_pids_of(process_name):
        try:
            self.kill(target_pid, signal=signal, as_root=as_root)
        except TargetStableError:
            # Best effort: the process may already have exited.
            pass

def get_pids_of(self, process_name: str) -> List[int]:
    """
    Return the PIDs of all running instances of the named process.
    Concrete targets must override this.
    """
    raise NotImplementedError()

def ps(self, **kwargs: Dict[str, Any]) -> List['PsEntry']:
    """
    Return a :class:`PsEntry` per running process on the system.
    Concrete targets must override this.
    """
    raise NotImplementedError()
+ """ command = 'if [ -e {} ]; then echo 1; else echo 0; fi' - output = await self.execute.asyn(command.format(quote(filepath)), as_root=self.is_rooted) + output: str = await self.execute.asyn(command.format(quote(filepath)), as_root=self.is_rooted) return boolean(output.strip()) @asyn.asyncf - async def directory_exists(self, filepath): + async def directory_exists(self, filepath: str) -> bool: + """ + Check if the path on the target is an existing directory. + + :param filepath: The path to check. + :return: True if a directory exists at the path, else False. + """ output = await self.execute.asyn('if [ -d {} ]; then echo 1; else echo 0; fi'.format(quote(filepath))) # output from ssh my contain part of the expression in the buffer, # split out everything except the last word. return boolean(output.split()[-1]) # pylint: disable=maybe-no-member @asyn.asyncf - async def list_file_systems(self): - output = await self.execute.asyn('mount') - fstab = [] + async def list_file_systems(self) -> List['FstabEntry']: + """ + Return a list of currently mounted file systems, parsed into FstabEntry objects. + + :return: A list of file system entries describing mount points. + """ + output: str = await self.execute.asyn('mount') + fstab: List['FstabEntry'] = [] for line in output.split('\n'): line = line.strip() if not line: continue - match = FSTAB_ENTRY_REGEX.search(line) + match: Optional[Match[str]] = FSTAB_ENTRY_REGEX.search(line) if match: fstab.append(FstabEntry(match.group(1), match.group(2), match.group(3), match.group(4), @@ -1321,20 +1918,53 @@ async def list_file_systems(self): return fstab @asyn.asyncf - async def list_directory(self, path, as_root=False): + async def list_directory(self, path: str, as_root: bool = False) -> List[str]: + """ + Internal method that returns the contents of a directory. Called by + :meth:`list_directory`. + + :param path: Directory path to list. + :param as_root: If True, list as root. 
@asyn.asyncf
async def list_directory(self, path: str, as_root: bool = False) -> List[str]:
    """
    List the entries of a directory on the target.

    :param path: Directory to list.
    :param as_root: Perform the listing as root.
    :return: Names of the entries in ``path``.
    """
    self.async_manager.track_access(
        asyn.PathAccess(namespace='target', path=path, mode='r')
    )
    return await self._list_directory(path, as_root=as_root)

async def _list_directory(self, path: str, as_root: bool = False) -> List[str]:
    """
    Backend for :meth:`list_directory`; concrete targets must override it.

    :param path: Directory to list.
    :param as_root: Perform the listing as root.
    :raises NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

def get_workpath(self, name: str) -> Optional[str]:
    """
    Resolve ``name`` against :attr:`working_directory`.

    :param name: File name to append to the working directory.
    :return: The joined absolute path, or ``None`` when no target path
        helper is available.
    """
    if not self.path:
        return None
    return self.path.join(self.working_directory, name)
+ """ prefix = f'{prefix}-' if prefix else '' suffix = f'-{suffix}' if suffix else '' name = '{prefix}{uuid}{suffix}'.format( @@ -1342,99 +1972,212 @@ async def tempfile(self, prefix=None, suffix=None): uuid=uuid.uuid4().hex, suffix=suffix, ) - path = self.path.join(self.tmp_directory, name) + path = self.path.join(self.tmp_directory, name) if self.path else '' if (await self.file_exists.asyn(path)): raise FileExistsError('Path already exists on the target: {}'.format(path)) else: return path @asyn.asyncf - async def remove(self, path, as_root=False): + async def remove(self, path: str, as_root=False) -> None: + """ + Remove a file or directory on the target. + + :param path: Path to remove. + :param as_root: If True, remove as root. + """ await self.execute.asyn('rm -rf -- {}'.format(quote(path)), as_root=as_root) # misc @asyn.asyncf - async def read_sysctl(self, parameter): + async def read_sysctl(self, parameter: str) -> Optional[str]: """ - Returns the value of the given sysctl parameter as a string. + Read the specified sysctl parameter. Equivalent to reading the file under + ``/proc/sys/...``. + + :param parameter: The sysctl name, e.g. "kernel.sched_latency_ns". + :return: The value of the sysctl parameter, or None if not found. + :raises ValueError: If the sysctl parameter doesn't exist. """ - path = self.path.join('/', 'proc', 'sys', *parameter.split('.')) - try: - return await self.read_value.asyn(path) - except FileNotFoundError as e: - raise ValueError(f'systcl parameter {parameter} was not found: {e}') + if self.path: + path: str = self.path.join('/', 'proc', 'sys', *parameter.split('.')) + try: + return await self.read_value.asyn(path) + except FileNotFoundError as e: + raise ValueError(f'systcl parameter {parameter} was not found: {e}') + return None + + def core_cpus(self, core: str) -> List[int]: + """ + Return numeric CPU IDs corresponding to the given core name. - def core_cpus(self, core): + :param core: The name of the CPU core (e.g., "A53"). 
+ :return: List of CPU indices that match the given name. + """ return [i for i, c in enumerate(self.core_names) if c == core] @asyn.asyncf - async def list_online_cpus(self, core=None): - path = self.path.join('/sys/devices/system/cpu/online') - output = await self.read_value.asyn(path) - all_online = ranges_to_list(output) - if core: - cpus = self.core_cpus(core) - if not cpus: - raise ValueError(core) - return [o for o in all_online if o in cpus] - else: - return all_online + async def list_online_cpus(self, core: Optional[str] = None) -> Optional[List[int]]: + """ + Return a list of online CPU IDs. If a core name is provided, restricts + to CPUs that match that name. + + :param core: Optional name of the CPU core (e.g., "A53") to filter results. + :return: Online CPU IDs. + :raises ValueError: If the specified core name is invalid. + """ + if self.path: + path: str = self.path.join('/sys/devices/system/cpu/online') + output: str = await self.read_value.asyn(path) + all_online: List[int] = ranges_to_list(output) + if core: + cpus: List[int] = self.core_cpus(core) + if not cpus: + raise ValueError(core) + return [o for o in all_online if o in cpus] + else: + return all_online + return None @asyn.asyncf - async def list_offline_cpus(self): - online = await self.list_online_cpus.asyn() + async def list_offline_cpus(self) -> List[int]: + """ + Return a list of offline CPU IDs, i.e., those not present in + :meth:`list_online_cpus`. + + :return: Offline CPU IDs. 
+ """ + online: List[int] = await self.list_online_cpus.asyn() return [c for c in range(self.number_of_cpus) if c not in online] @asyn.asyncf - async def getenv(self, variable): - var = await self.execute.asyn('printf "%s" ${}'.format(variable)) + async def getenv(self, variable: str) -> str: + """ + Return the value of the specified environment variable on the device + """ + var: str = await self.execute.asyn('printf "%s" ${}'.format(variable)) return var.rstrip('\r\n') - def capture_screen(self, filepath): + def capture_screen(self, filepath: str) -> None: + """ + Take a screenshot on the device and save it to the specified file on the + host. This may not be supported by the target. You can optionally insert a + ``{ts}`` tag into the file name, in which case it will be substituted with + on-target timestamp of the screen shot in ISO8601 format. + + :param filepath: Path on the host where screenshot is stored. + :raises NotImplementedError: If screenshot capture is not implemented. + """ raise NotImplementedError() - def install(self, filepath, timeout=None, with_name=None): + @asyn.asyncf + def install(self, filepath: str, timeout: Optional[int] = None, with_name: Optional[str] = None) -> str: + """ + Install an executable from the host to the target. If `with_name` is given, + the file is renamed on the target. + + :param filepath: Path on the host to the executable. + :param timeout: Timeout in seconds for the installation. + :param with_name: If provided, rename the installed file on the target. + :return: The path to the installed binary on the target. + :raises NotImplementedError: If not implemented in a subclass. + """ raise NotImplementedError() - def uninstall(self, name): + def uninstall(self, name: str) -> None: + """ + Uninstall a previously installed executable. + + :param name: Name of the executable to remove. + :raises NotImplementedError: If not implemented in a subclass. 
+ """ raise NotImplementedError() @asyn.asyncf - async def get_installed(self, name, search_system_binaries=True): + async def get_installed(self, name: str, search_system_binaries: bool = True) -> Optional[str]: + """ + Return the absolute path of an installed executable with the given name, + or None if not found. + + :param name: The name of the binary. + :param search_system_binaries: If True, also search the system PATH. + :return: Full path to the binary on the target, or None if not found. + """ # Check user installed binaries first if self.file_exists(self.executables_directory): - if name in (await self.list_directory.asyn(self.executables_directory)): + if name in (await self.list_directory.asyn(self.executables_directory)) and self.path: return self.path.join(self.executables_directory, name) # Fall back to binaries in PATH if search_system_binaries: - PATH = await self.getenv.asyn('PATH') - for path in PATH.split(self.path.pathsep): - try: - if name in (await self.list_directory.asyn(path)): - return self.path.join(path, name) - except TargetStableError: - pass # directory does not exist or no executable permissions + PATH: str = await self.getenv.asyn('PATH') + if self.path: + for path in PATH.split(self.path.pathsep): + try: + if name in (await self.list_directory.asyn(path)): + return self.path.join(path, name) + except TargetStableError: + pass # directory does not exist or no executable permissions + return None - which = get_installed + which: '_AsyncPolymorphicFunction' = get_installed @asyn.asyncf - async def install_if_needed(self, host_path, search_system_binaries=True, timeout=None): - - binary_path = await self.get_installed.asyn(os.path.split(host_path)[1], - search_system_binaries=search_system_binaries) + async def install_if_needed(self, host_path: str, search_system_binaries: bool = True, + timeout: Optional[int] = None) -> str: + """ + Check whether an executable with the name of ``host_path`` is already installed + on the target. 
If it is not installed, install it from the specified path. + + :param host_path: The path to the executable on the host system. + :param search_system_binaries: If ``True``, also search the device's system PATH + for the binary before installing. If ``False``, only check user-installed + binaries. + :param timeout: Maximum time in seconds to wait for installation to complete. + If ``None``, a default (implementation-defined) timeout is used. + :return: The absolute path of the binary on the target after ensuring it is installed. + + :raises TargetError: If the target is disconnected. + :raises TargetStableError: If installation fails or times out (depending on implementation). + """ + binary_path: str = await self.get_installed.asyn(os.path.split(host_path)[1], + search_system_binaries=search_system_binaries) if not binary_path: binary_path = await self.install.asyn(host_path, timeout=timeout) return binary_path @asyn.asyncf - async def is_installed(self, name): + async def is_installed(self, name: str) -> bool: + """ + Determine whether an executable with the specified name is installed on the target. + + :param name: Name of the executable (e.g. "perf"). + :return: ``True`` if the executable is found, otherwise ``False``. + + :raises TargetError: If the target is not currently connected. + """ return bool(await self.get_installed.asyn(name)) - def bin(self, name): + def bin(self, name: str) -> str: + """ + Retrieve the installed path to the specified binary on the target. + + :param name: Name of the binary whose path is requested. + :return: The path to the binary if installed and recorded by devlib, + otherwise returns ``name`` unmodified. + """ return self._installed_binaries.get(name, name) - def has(self, modname): + def has(self, modname: str) -> bool: + """ + Check whether the specified module or feature is present on the target. + + :param modname: Module name to look up. + :return: ``True`` if the module is present and loadable, otherwise ``False``. 
+ + :raises Exception: If an unexpected error occurs while querying the module. + (Can be replaced with a more specific exception if desired.) + """ modname = identifier(modname) try: self._get_module(modname, log=False) @@ -1444,9 +2187,15 @@ def has(self, modname): return True @asyn.asyncf - async def lsmod(self): - lines = (await self.execute.asyn('lsmod')).splitlines() - entries = [] + async def lsmod(self) -> List['LsmodEntry']: + """ + Run the ``lsmod`` command on the target and return the result as a list + of :class:`LsmodEntry` namedtuples. + + :return: A list of loaded kernel modules, each represented by an LsmodEntry object. + """ + lines: str = (await self.execute.asyn('lsmod')).splitlines() + entries: List['LsmodEntry'] = [] for line in lines[1:]: # first line is the header if not line.strip(): continue @@ -1459,13 +2208,20 @@ async def lsmod(self): return entries @asyn.asyncf - async def insmod(self, path): - target_path = self.get_workpath(os.path.basename(path)) + async def insmod(self, path: str) -> None: + """ + Insert a kernel module onto the target via ``insmod``. + + :param path: The path on the *host* system to the kernel module file (.ko). + :raises TargetStableError: If the module cannot be inserted (e.g., missing dependencies). + """ + target_path: Optional[str] = self.get_workpath(os.path.basename(path)) await self.push.asyn(path, target_path) - await self.execute.asyn('insmod {}'.format(quote(target_path)), as_root=True) + if target_path: + await self.execute.asyn('insmod {}'.format(quote(target_path)), as_root=True) @asyn.asyncf - async def extract(self, path, dest=None): + async def extract(self, path: str, dest: Optional[str] = None) -> Optional[str]: """ Extract the specified on-target file. The extraction method to be used (unzip, gunzip, bunzip2, or tar) will be based on the file's extension. 
@@ -1482,39 +2238,77 @@ async def extract(self, path, dest=None): (``dest`` if it was specified otherwise, the directory that contained the archive). + :param path: The on-target path of the archive or compressed file. + :param dest: An optional directory path on the target where the contents + should be extracted. The directory must already exist. + :return: Path to the extracted files. + * If a multi-file archive, returns the directory containing those files. + * If a single-file compression (e.g., .gz, .bz2), returns the path to + the decompressed file. + * If extraction fails or is unknown format, ``None`` might be returned + (depending on your usage). + + :raises ValueError: If the file’s format is unrecognized. + :raises TargetStableError: If extraction fails on the target. """ for ending in ['.tar.gz', '.tar.bz', '.tar.bz2', '.tgz', '.tbz', '.tbz2']: if path.endswith(ending): return await self._extract_archive(path, 'tar xf {} -C {}', dest) - ext = self.path.splitext(path)[1] - if ext in ['.bz', '.bz2']: - return await self._extract_file(path, 'bunzip2 -f {}', dest) - elif ext == '.gz': - return await self._extract_file(path, 'gunzip -f {}', dest) - elif ext == '.zip': - return await self._extract_archive(path, 'unzip {} -d {}', dest) - else: - raise ValueError('Unknown compression format: {}'.format(ext)) + if self.path: + ext: str = self.path.splitext(path)[1] + if ext in ['.bz', '.bz2']: + return await self._extract_file(path, 'bunzip2 -f {}', dest) + elif ext == '.gz': + return await self._extract_file(path, 'gunzip -f {}', dest) + elif ext == '.zip': + return await self._extract_archive(path, 'unzip {} -d {}', dest) + else: + raise ValueError('Unknown compression format: {}'.format(ext)) + return None @asyn.asyncf - async def sleep(self, duration): + async def sleep(self, duration: int) -> None: + """ + Invoke a ``sleep`` command on the target to pause for the specified duration. + + :param duration: The time in seconds the target should sleep. 
+ :raises TimeoutError: If the sleep operation times out (rare, but can be forced). + """ timeout = duration + 10 await self.execute.asyn('sleep {}'.format(duration), timeout=timeout) @asyn.asyncf - async def read_tree_tar_flat(self, path, depth=1, check_exit_code=True, - decode_unicode=True, strip_null_chars=True): + async def read_tree_tar_flat(self, path: str, depth: int = 1, check_exit_code: bool = True, + decode_unicode: bool = True, strip_null_chars: bool = True) -> Dict[str, str]: + """ + Recursively read file nodes within a tar archive stored on the target, up to + a given ``depth``. The archive is temporarily extracted in memory, and the + contents are returned in a flat dictionary mapping each file path to its content. + + :param path: Path to the tar archive on the target. + :param depth: Maximum directory depth to traverse within the archive. + :param check_exit_code: If ``True``, raise an error if the helper command exits non-zero. + :param decode_unicode: If ``True``, attempt to decode each file’s content as UTF-8. + :param strip_null_chars: If ``True``, strip out any null characters (``\\x00``) from + decoded text. + :return: A dictionary mapping file paths (within the archive) to their textual content. + + :raises TargetStableError: If the helper command fails or returns unexpected data. + :raises UnicodeDecodeError: If a file's content cannot be decoded when + ``decode_unicode=True``. 
+ """ self.async_manager.track_access( asyn.PathAccess(namespace='target', path=path, mode='r') ) - command = 'read_tree_tgz_b64 {} {} {}'.format(quote(path), depth, - quote(self.working_directory)) - output = await self._execute_util.asyn(command, as_root=self.is_rooted, - check_exit_code=check_exit_code) + if path and self.working_directory: + command = 'read_tree_tgz_b64 {} {} {}'.format(quote(path), depth, + quote(self.working_directory)) + output: str = await self._execute_util.asyn(command, as_root=self.is_rooted, + check_exit_code=check_exit_code) - result = {} + result: Dict[str, str] = {} # Unpack the archive in memory tar_gz = base64.b64decode(output) @@ -1533,25 +2327,36 @@ async def read_tree_tar_flat(self, path, depth=1, check_exit_code=True, content = content_f.read() if decode_unicode: try: - content = content.decode('utf-8').strip() + content_str = content.decode('utf-8').strip() if strip_null_chars: - content = content.replace('\x00', '').strip() + content_str = content_str.replace('\x00', '').strip() except UnicodeDecodeError: - content = '' - - name = self.path.join(path, member.name) - result[name] = content + content_str = '' + if self.path: + name: str = self.path.join(path, member.name) + result[name] = content_str return result @asyn.asyncf - async def read_tree_values_flat(self, path, depth=1, check_exit_code=True): + async def read_tree_values_flat(self, path: str, depth: int = 1, check_exit_code: bool = True) -> Dict[str, str]: + """ + Recursively read file nodes under a given directory (e.g., sysfs) on the target, + up to the specified depth, returning a flat dictionary of file paths to contents. + + :param path: The on-target directory path to read from. + :param depth: Maximum directory depth to traverse. + :param check_exit_code: If ``True``, raises an error if the helper command fails. + :return: A dict mapping each discovered file path to the file's textual content. 
+ + :raises TargetStableError: If the read-tree helper command fails or no content is returned. + """ self.async_manager.track_access( asyn.PathAccess(namespace='target', path=path, mode='r') ) - command = 'read_tree_values {} {}'.format(quote(path), depth) - output = await self._execute_util.asyn(command, as_root=self.is_rooted, - check_exit_code=check_exit_code) + command: str = 'read_tree_values {} {}'.format(quote(path), depth) + output: str = await self._execute_util.asyn(command, as_root=self.is_rooted, + check_exit_code=check_exit_code) accumulator = defaultdict(list) for entry in output.strip().split('\n'): @@ -1560,35 +2365,44 @@ async def read_tree_values_flat(self, path, depth=1, check_exit_code=True): path, value = entry.strip().split(':', 1) accumulator[path].append(value) - result = {k: '\n'.join(v).strip() for k, v in accumulator.items()} + result: Dict[str, str] = {k: '\n'.join(v).strip() for k, v in accumulator.items()} return result @asyn.asyncf - async def read_tree_values(self, path, depth=1, dictcls=dict, - check_exit_code=True, tar=False, decode_unicode=True, - strip_null_chars=True): + async def read_tree_values(self, path: str, depth: int = 1, dictcls: Type[Dict] = dict, + check_exit_code: bool = True, tar: bool = False, decode_unicode: bool = True, + strip_null_chars: bool = True) -> Union[str, Dict[str, 'Node']]: """ - Reads the content of all files under a given tree - - :path: path to the tree - :depth: maximum tree depth to read - :dictcls: type of the dict used to store the results - :check_exit_code: raise an exception if the shutil command fails - :tar: fetch the entire tree using tar rather than just the value (more - robust but slower in some use-cases) - :decode_unicode: decode the content of tar-ed files as utf-8 - :strip_null_chars: remove '\x00' chars from the content of utf-8 - decoded files - - :returns: a tree-like dict with the content of files as leafs + Recursively read all file nodes under a given directory or tar 
archive on the target, + building a **tree-like** structure up to the given depth. + + :param path: On-target path to read. May be a directory path or a tar file path + if ``tar=True``. + :param depth: Maximum directory depth to traverse. + :param dictcls: The dictionary class to use for constructing the tree + (defaults to the built-in :class:`dict`). + :param check_exit_code: If ``True``, raises an error if the internal helper command fails. + :param tar: If ``True``, treat ``path`` as a tar archive and read it. If ``False``, + read from a normal directory hierarchy. + :param decode_unicode: If ``True``, decode file contents (in tar mode) as UTF-8. + :param strip_null_chars: If ``True``, strip out any null characters (``\\x00``) from + decoded text. + :return: A hierarchical dictionary (or specialized mapping) containing sub-directories + and files as nested keys, or a string in some edge cases (depending on usage). + + :raises TargetStableError: If the read-tree operation fails. + :raises UnicodeDecodeError: If a file content cannot be decoded. """ if not tar: - value_map = await self.read_tree_values_flat.asyn(path, depth, check_exit_code) + value_map: Dict[str, str] = await self.read_tree_values_flat.asyn(path, depth, check_exit_code) else: value_map = await self.read_tree_tar_flat.asyn(path, depth, check_exit_code, - decode_unicode, - strip_null_chars) - return _build_path_tree(value_map, path, self.path.sep, dictcls) + decode_unicode, + strip_null_chars) + if self.path: + return _build_path_tree(value_map, path, self.path.sep, dictcls) + else: + return {} def install_module(self, mod, **params): mod = get_module(mod) @@ -1602,74 +2416,166 @@ def install_module(self, mod, **params): # internal methods @asyn.asyncf - async def _setup_scripts(self): + async def _setup_scripts(self) -> None: + """ + Install and prepare the ``shutils`` script on the target. This script provides + shell utility functions that may be invoked by other devlib features. 
+ + :raises TargetStableError: + If ``busybox`` is not installed or if pushing/installing ``shutils`` fails. + :raises IOError: + If reading the local script file fails on the host system. + """ scripts = os.path.join(PACKAGE_BIN_DIRECTORY, 'scripts') shutils_ifile = os.path.join(scripts, 'shutils.in') with open(shutils_ifile) as fh: - lines = fh.readlines() + lines: List[str] = fh.readlines() with tempfile.TemporaryDirectory() as folder: - shutils_ofile = os.path.join(folder, 'shutils') + shutils_ofile: str = os.path.join(folder, 'shutils') with open(shutils_ofile, 'w') as ofile: - for line in lines: - line = line.replace("__DEVLIB_BUSYBOX__", self.busybox) - ofile.write(line) + if self.busybox: + for line in lines: + line = line.replace("__DEVLIB_BUSYBOX__", self.busybox) + ofile.write(line) self._shutils = await self.install.asyn(shutils_ofile) await self.install.asyn(os.path.join(scripts, 'devlib-signal-target')) @asyn.asyncf @call_conn - async def _execute_util(self, command, timeout=None, check_exit_code=True, as_root=False): - command = '{} sh {} {}'.format(quote(self.busybox), quote(self.shutils), command) - return await self.execute.asyn( - command, - timeout=timeout, - check_exit_code=check_exit_code, - as_root=as_root - ) + async def _execute_util(self, command: SubprocessCommand, timeout: Optional[int] = None, + check_exit_code: bool = True, as_root: bool = False) -> Optional[str]: + """ + Execute a shell utility command via the ``shutils`` script on the target. + This typically prepends the busybox and shutils script calls before your + specified command. + + :param command: The command (or SubprocessCommand) string to run. + :param timeout: Maximum number of seconds to allow for completion. If None, + an implementation-defined default is used. + :param check_exit_code: If True, raise an error when the return code is non-zero. + :param as_root: If True, attempt to run with root privileges (e.g., ``su`` + or ``sudo``). 
+ :return: The command's output on success, or ``None`` if busybox/shutils is + unavailable. + + :raises TargetStableError: If the script is not present or the command fails + with a non-zero code (while ``check_exit_code=True``). + :raises TimeoutError: If the command runs longer than the specified timeout. + """ + if self.busybox and self.shutils: + command_str = '{} sh {} {}'.format(quote(self.busybox), quote(self.shutils), cast(str, command)) + return await self.execute.asyn( + command_str, + timeout=timeout, + check_exit_code=check_exit_code, + as_root=as_root + ) + return None - async def _extract_archive(self, path, cmd, dest=None): + async def _extract_archive(self, path: str, cmd: str, dest: Optional[str] = None) -> Optional[str]: + """ + extract files of type - + '.tar.gz', '.tar.bz', '.tar.bz2', '.tgz', '.tbz', '.tbz2' + + :param path: On-target path of the compressed archive (e.g., .tar.gz). + :param cmd: A template string for the extraction command (e.g., 'tar xf {} -C {}'). + :param dest: Optional path to a destination directory on the target + where files are extracted. If not specified, extraction occurs in + the same directory as ``path``. + :return: The directory or file path where the archive's contents were extracted, + or None if ``busybox`` or other prerequisites are missing. + + :raises TargetStableError: If extraction fails or the file/directory cannot be written. 
+ """ cmd = '{} ' + cmd # busybox if dest: - extracted = dest + extracted: Optional[str] = dest else: - extracted = self.path.dirname(path) - cmdtext = cmd.format(quote(self.busybox), quote(path), quote(extracted)) - await self.execute.asyn(cmdtext) + if self.path: + extracted = self.path.dirname(path) + if self.busybox and extracted: + cmdtext = cmd.format(quote(self.busybox), quote(path), quote(extracted)) + await self.execute.asyn(cmdtext) return extracted - async def _extract_file(self, path, cmd, dest=None): + async def _extract_file(self, path: str, cmd: str, dest: Optional[str] = None) -> Optional[str]: + """ + Decompress a single file on the target (e.g., .gz, .bz2). + + :param path: On-target path of the compressed file. + :param cmd: The decompression command format string (e.g., 'gunzip -f {}'). + :param dest: Optional directory path on the target where the decompressed file + should be moved. If omitted, the file remains in its original directory + (with the extension removed). + :return: The path to the decompressed file after extraction, or None if + prerequisites are missing. + + :raises TargetStableError: If decompression fails or the file/directory is unwritable. 
+ """ cmd = '{} ' + cmd # busybox - cmdtext = cmd.format(quote(self.busybox), quote(path)) - await self.execute.asyn(cmdtext) - extracted = self.path.splitext(path)[0] - if dest: - await self.execute.asyn('mv -f {} {}'.format(quote(extracted), quote(dest))) - if dest.endswith('/'): - extracted = self.path.join(dest, self.path.basename(extracted)) - else: - extracted = dest - return extracted + if self.busybox and self.path: + cmdtext: str = cmd.format(quote(self.busybox), quote(path)) + await self.execute.asyn(cmdtext) + extracted: Optional[str] = self.path.splitext(path)[0] + if dest and extracted: + await self.execute.asyn('mv -f {} {}'.format(quote(extracted), quote(dest))) + if dest.endswith('/'): + extracted = self.path.join(dest, self.path.basename(extracted)) + else: + extracted = dest + return extracted + return None - def _install_module(self, mod, params, log=True): - mod = get_module(mod) - name = mod.name - if params is None or self._modules.get(name, {}) is None: - raise TargetStableError(f'Could not load module "{name}" as it has been explicilty disabled') - else: - try: - return mod.install(self, **params) - except Exception as e: - if log: - self.logger.error(f'Module "{name}" failed to install on target: {e}') - raise + def _install_module(self, mod: Union[str, Type[Module]], + params: Dict[str, Type[Module]], log: bool = True) -> Optional[Module]: + """ + Installs a devlib module onto the target post-setup. + + :param mod: Either the module's name (string) or a Module type object. + :param params: A dictionary of parameters for initializing the module. + :param log: If True, logs errors if installation fails. + :return: The instantiated Module object if installation succeeds, otherwise None. + + :raises TargetStableError: If the module has been explicitly disabled or if + initialization fails irrecoverably. + :raises Exception: If any other unexpected error occurs. 
+ """ + module = get_module(mod) + name = module.name + if name: + if params is None or self._modules.get(name, {}) is None: + raise TargetStableError(f'Could not load module "{name}" as it has been explicilty disabled') + else: + try: + return module.install(self, **params) + except Exception as e: + if log: + self.logger.error(f'Module "{name}" failed to install on target: {e}') + raise + raise TargetStableError('Failed to install module as module name is not present') @property - def modules(self): + def modules(self) -> List[str]: + """ + A list of module names registered on this target, regardless of which + have been installed. + + :return: Sorted list of module names. + """ return sorted(self._modules.keys()) - def _update_modules(self, stage): - to_install = [ + def _update_modules(self, stage: str) -> None: + """ + Load or install modules that match the specified stage (e.g., "early", + "connected", or "setup"). + + :param stage: The stage name used for grouping when modules should be installed. + + :raises Exception: If a module fails installation or is not supported + by the target (caught and logged internally). + """ + to_install: List[Tuple[Type[Module], Dict[str, Type[Module]]]] = [ (mod, params) for mod, params in ( (get_module(name), params) @@ -1681,10 +2587,18 @@ def _update_modules(self, stage): try: self._install_module(mod, params) except Exception as e: - mod_name = mod.name self.logger.warning(f'Module {mod.name} is not supported by the target: {e}') - def _get_module(self, modname, log=True): + def _get_module(self, modname: str, log: bool = True) -> Module: + """ + Retrieve or install a module by name. If not already installed, this + attempts to install it first. + + :param modname: The name or attribute of the module to retrieve. + :param log: If True, logs errors if installation fails. + :return: The installed module object, if successful. + :raises AttributeError: If the module or attribute cannot be found or installed. 
+ """ try: return self._installed_modules[modname] except KeyError: @@ -1698,12 +2612,12 @@ def _get_module(self, modname, log=True): except ValueError: for _mod, _params in self._modules.items(): try: - _mod = get_module(_mod) + _module = get_module(_mod) except ValueError: pass else: - if _mod.attr_name == modname: - mod = _mod + if _module.attr_name == modname: + mod = _module params = _params break else: @@ -1711,12 +2625,22 @@ def _get_module(self, modname, log=True): f"'{self.__class__.__name__}' object has no attribute '{modname}'" ) else: - params = self._modules.get(mod.name, {}) + if mod.name: + params = self._modules.get(mod.name, {}) self._install_module(mod, params, log=log) return self.__getattr__(modname) - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Module: + """ + Fallback attribute accessor, invoked if a normal attribute or method + is not found. This checks for a corresponding installed or installable + module whose name matches ``attr``. + + :param attr: The module name or attribute to fetch. + :return: The installed module if found/installed, otherwise raises AttributeError. + :raises AttributeError: If the module does not exist or cannot be installed. + """ # When unpickled, objects will have an empty dict so fail early if attr.startswith('__') and attr.endswith('__'): raise AttributeError(attr) @@ -1728,11 +2652,28 @@ def __getattr__(self, attr): # work as expected raise AttributeError(str(e)) - def _resolve_paths(self): + def _resolve_paths(self) -> None: + """ + Perform final path resolutions, such as setting the target's working directory, + file transfer cache, or executables directory. + + :raises NotImplementedError: If the target subclass has not overridden this method. 
+ """ raise NotImplementedError() @asyn.asyncf - async def is_network_connected(self): + async def is_network_connected(self) -> bool: + """ + Check if the target has basic network/internet connectivity by using + ``ping`` to reach a known IP (e.g., 8.8.8.8). + + :return: True if the network appears to be reachable; False otherwise. + + :raises TargetStableError: If the network is known to be unreachable or if + the shell command reports a fatal error. + :raises TimeoutError: If repeatedly pinging does not respond within + the default or user-defined time. + """ self.logger.debug('Checking for internet connectivity...') timeout_s = 5 @@ -1750,7 +2691,7 @@ async def is_network_connected(self): await self.execute.asyn(command) return True except TargetStableError as e: - err = str(e).lower() + err: str = str(e).lower() if '100% packet loss' in err: # We sent a packet but got no response. # Try again - we don't want this to fail just because of a @@ -1773,14 +2714,32 @@ async def is_network_connected(self): class LinuxTarget(Target): + """ + A specialized :class:`Target` subclass for devices or systems running Linux. + Adapts path handling to ``posixpath`` and includes additional helpers for + Linux-specific commands or filesystems. + + :ivar path: Set to ``posixpath``. + :vartype path: ModuleType + :ivar os: ``"linux"`` + :vartype os: str + """ - path = posixpath + path: ModuleType = posixpath os = 'linux' @property @memoized - def abi(self): - value = self.execute('uname -m').strip() + def abi(self) -> str: + """ + Determine the Application Binary Interface (ABI) of the device by + interpreting the output of ``uname -m`` and mapping it to known + architecture strings in ``ABI_MAP``. + + :return: The ABI string (e.g., "arm64" or "x86_64"). If unmapped, + returns the exact output of ``uname -m``. 
+ """ + value: str = self.execute('uname -m').strip() for abi, architectures in ABI_MAP.items(): if value in architectures: result = abi @@ -1791,34 +2750,47 @@ def abi(self): @property @memoized - def os_version(self): - os_version = {} + def os_version(self) -> Dict[str, str]: + """ + Gather Linux distribution or version info by scanning files in ``/etc/`` + that end with ``-release`` or ``-version``. + + :return: A dictionary mapping the filename (e.g. "os-release") to + its contents as a single line. + """ + os_version: Dict[str, str] = {} command = 'ls /etc/*-release /etc*-version /etc/*_release /etc/*_version 2>/dev/null' - version_files = self.execute(command, check_exit_code=False).strip().split() + version_files: List[str] = self.execute(command, check_exit_code=False).strip().split() for vf in version_files: - name = self.path.basename(vf) - output = self.read_value(vf) + name: str = self.path.basename(vf) + output: str = self.read_value(vf) os_version[name] = convert_new_lines(output.strip()).replace('\n', ' ') return os_version @property @memoized - def system_id(self): + def system_id(self) -> str: + """ + Retrieve a Linux-specific system ID by invoking + a specialized utility command on the target. + + :return: A string uniquely identifying the Linux system. 
+ """ return self._execute_util('get_linux_system_id').strip() def __init__(self, - connection_settings=None, - platform=None, - working_directory=None, - executables_directory=None, - connect=True, - modules=None, - load_default_modules=True, - shell_prompt=DEFAULT_SHELL_PROMPT, - conn_cls=SshConnection, - is_container=False, - max_async=50, - tmp_directory=None, + connection_settings: Optional[UserConnectionSettings] = None, + platform: Optional[Platform] = None, + working_directory: Optional[str] = None, + executables_directory: Optional[str] = None, + connect: bool = True, + modules: Optional[Dict[str, Dict[str, Type[Module]]]] = None, + load_default_modules: bool = True, + shell_prompt: Pattern[str] = DEFAULT_SHELL_PROMPT, + conn_cls: 'InitCheckpointMeta' = SshConnection, + is_container: bool = False, + max_async: int = 50, + tmp_directory: Optional[str] = None, ): super(LinuxTarget, self).__init__(connection_settings=connection_settings, platform=platform, @@ -1834,23 +2806,38 @@ def __init__(self, tmp_directory=tmp_directory, ) - def wait_boot_complete(self, timeout=10): + def wait_boot_complete(self, timeout: Optional[int] = 10) -> None: + """ + wait for target to boot up + """ pass @asyn.asyncf - async def get_pids_of(self, process_name): - """Returns a list of PIDs of all processes with the specified name.""" + async def get_pids_of(self, process_name) -> List[int]: + """ + Return a list of PIDs of all running processes matching the given name. + + :param process_name: Name of the process to look up. + :return: List of matching PIDs. + :raises NotImplementedError: If not overridden by child classes. 
+ """ # result should be a column of PIDs with the first row as "PID" header - result = await self.execute.asyn('ps -C {} -o pid'.format(quote(process_name)), # NOQA - check_exit_code=False) - result = result.strip().split() + result_temp:str = await self.execute.asyn('ps -C {} -o pid'.format(quote(process_name)), # NOQA + check_exit_code=False) + result: List[str] = result_temp.strip().split() if len(result) >= 2: # at least one row besides the header return list(map(int, result[1:])) else: return [] @asyn.asyncf - async def ps(self, threads=False, **kwargs): + async def ps(self, threads: bool = False, **kwargs: Dict[str, Any]) -> List['PsEntry']: + """ + Return a list of PsEntry objects for each process on the system. + + :return: A list of processes. + :raises NotImplementedError: If not overridden. + """ ps_flags = '-eo' if threads: ps_flags = '-eLo' @@ -1858,43 +2845,63 @@ async def ps(self, threads=False, **kwargs): out = await self.execute.asyn(command) - result = [] - lines = convert_new_lines(out).splitlines() + result: List['PsEntry'] = [] + lines: List[str] = convert_new_lines(out).splitlines() # Skip header for line in lines[1:]: - parts = re.split(r'\s+', line, maxsplit=9) + parts: List[str] = re.split(r'\s+', line, maxsplit=9) if parts: result.append(PsEntry(*(parts[0:1] + list(map(int, parts[1:6])) + parts[6:]))) if not kwargs: return result else: - filtered_result = [] + filtered_result: List['PsEntry'] = [] for entry in result: if all(getattr(entry, k) == v for k, v in kwargs.items()): filtered_result.append(entry) return filtered_result - async def _list_directory(self, path, as_root=False): + async def _list_directory(self, path: str, as_root: bool = False) -> List[str]: + """ + target specific implementation of list_directory + """ contents = await self.execute.asyn('ls -1 {}'.format(quote(path)), as_root=as_root) return [x.strip() for x in contents.split('\n') if x.strip()] @asyn.asyncf - async def install(self, filepath, timeout=None, 
with_name=None): # pylint: disable=W0221 - destpath = self.path.join(self.executables_directory, - with_name and with_name or self.path.basename(filepath)) + async def install(self, filepath: str, timeout: Optional[int] = None, + with_name: Optional[str] = None) -> str: # pylint: disable=W0221 + """ + Install an executable on the device. + + :param filepath: path to the executable on the host + :param timeout: Optional timeout (in seconds) for the installation + :param with_name: This may be used to rename the executable on the target + """ + destpath: str = self.path.join(self.executables_directory, + with_name and with_name or self.path.basename(filepath)) await self.push.asyn(filepath, destpath, timeout=timeout) await self.execute.asyn('chmod a+x {}'.format(quote(destpath)), timeout=timeout) self._installed_binaries[self.path.basename(destpath)] = destpath return destpath @asyn.asyncf - async def uninstall(self, name): - path = self.path.join(self.executables_directory, name) + async def uninstall(self, name: str) -> None: + """ + Uninstall the specified executable from the target + """ + path: str = self.path.join(self.executables_directory, name) await self.remove.asyn(path) @asyn.asyncf - async def capture_screen(self, filepath): + async def capture_screen(self, filepath: str) -> None: + """ + Take a screenshot on the device and save it to the specified file on the + host. This may not be supported by the target. You can optionally insert a + ``{ts}`` tag into the file name, in which case it will be substituted with + on-target timestamp of the screen shot in ISO8601 format. 
+ """ if not (await self.is_installed.asyn('scrot')): self.logger.debug('Could not take screenshot as scrot is not installed.') return @@ -1902,37 +2909,82 @@ async def capture_screen(self, filepath): async with self.make_temp(is_directory=False) as tmpfile: cmd = 'DISPLAY=:0.0 scrot {} && {} date -u -Iseconds' - ts = (await self.execute.asyn(cmd.format(quote(tmpfile), quote(self.busybox)))).strip() - filepath = filepath.format(ts=ts) - await self.pull.asyn(tmpfile, filepath) + if self.busybox: + ts: str = (await self.execute.asyn(cmd.format(quote(tmpfile), quote(self.busybox)))).strip() + filepath = filepath.format(ts=ts) + await self.pull.asyn(tmpfile, filepath) + else: + raise TargetStableError("busybox is not present") except TargetStableError as e: - if "Can't open X dispay." not in e.message: + if isinstance(e.message, str) and "Can't open X dispay." not in e.message: raise e - message = e.message.split('OUTPUT:', 1)[1].strip() # pylint: disable=no-member - self.logger.debug('Could not take screenshot: {}'.format(message)) + if isinstance(e.message, str): + message = e.message.split('OUTPUT:', 1)[1].strip() # pylint: disable=no-member + self.logger.debug('Could not take screenshot: {}'.format(message)) - def _resolve_paths(self): + def _resolve_paths(self) -> None: + """ + set paths for working directory, file transfer cache and executables directory + """ if self.working_directory is None: # This usually lands in the home directory self.working_directory = self.path.join(self.execute("pwd").strip(), 'devlib-target') class AndroidTarget(Target): - + """ + A specialized :class:`Target` subclass for devices running Android. This + provides additional Android-specific features like property retrieval + (``getprop``), APK installation, ADB connection management, screen controls, + input injection, and more. + + :param connection_settings: Parameters for connecting to the device + (e.g., ADB serial or host/port). 
+ :param platform: A ``Platform`` object describing hardware aspects. If None, + a generic or default platform is used. + :param working_directory: A directory on the device for devlib to store + temporary files. Defaults to a subfolder of external storage. + :param executables_directory: A directory on the device where devlib + installs binaries. Defaults to ``/data/local/tmp/bin``. + :param connect: If True, automatically connect to the device upon instantiation. + Otherwise, call :meth:`connect`. + :param modules: Additional modules to load (name -> parameters). + :param load_default_modules: If True, load all modules in :attr:`default_modules`. + :param shell_prompt: Regex matching the interactive shell prompt, if used. + :param conn_cls: The connection class, typically :class:`AdbConnection`. + :param package_data_directory: Location where installed packages store data. + Defaults to ``"/data/data"``. + :param is_container: If True, indicates the device is actually a container environment. + :param max_async: Maximum number of asynchronous operations to allow in parallel. + """ path = posixpath os = 'android' ls_command = '' @property @memoized - def abi(self): + def abi(self) -> str: + """ + Return the main ABI (CPU architecture) by reading ``ro.product.cpu.abi`` + from the device properties. + + :return: E.g. "arm64" or "armeabi-v7a" for an Android device. + """ return self.getprop()['ro.product.cpu.abi'].split('-')[0] @property @memoized - def supported_abi(self): - props = self.getprop() - result = [props['ro.product.cpu.abi']] + def supported_abi(self) -> List[Optional[str]]: + """ + List all supported ABIs found in Android system properties. Combines + values from ``ro.product.cpu.abi``, ``ro.product.cpu.abi2``, + and ``ro.product.cpu.abilist``. + + :return: A list of ABI strings (some might be mapped to devlib’s known + architecture list). 
+ """ + props: Dict[str, str] = self.getprop() + result: List[str] = [props['ro.product.cpu.abi']] if 'ro.product.cpu.abi2' in props: result.append(props['ro.product.cpu.abi2']) if 'ro.product.cpu.abilist' in props: @@ -1940,7 +2992,7 @@ def supported_abi(self): if abi not in result: result.append(abi) - mapped_result = [] + mapped_result: List[Optional[str]] = [] for supported_abi in result: for abi, architectures in ABI_MAP.items(): found = False @@ -1954,29 +3006,59 @@ def supported_abi(self): @property @memoized - def os_version(self): - os_version = {} + def os_version(self) -> Dict[str, str]: + """ + Read and parse Android build version info from properties whose keys + start with ``ro.build.version``. + + :return: Dictionary mapping the last component of each key + (e.g., "incremental", "release") to its string value. + """ + os_version: Dict[str, str] = {} for k, v in self.getprop().iteritems(): if k.startswith('ro.build.version'): - part = k.split('.')[-1] + part: str = k.split('.')[-1] os_version[part] = v return os_version @property - def adb_name(self): + def adb_name(self) -> Optional[str]: + """ + The ADB device name or serial number for the connected Android device. + + :return: + - The string serial/ID if connected via ADB (e.g. ``"0123456789ABCDEF"``). + - ``None`` if unavailable or a different connection type is used (e.g. SSH). + """ return getattr(self.conn, 'device', None) @property - def adb_server(self): + def adb_server(self) -> Optional[str]: + """ + The hostname or IP address of the ADB server, if using a remote ADB + connection. + + :return: + - The ADB server address (e.g. ``"127.0.0.1"``). + - ``None`` if not applicable (local ADB or a non-ADB connection). + """ return getattr(self.conn, 'adb_server', None) @property - def adb_port(self): + def adb_port(self) -> Optional[int]: + """ + The TCP port on which the ADB server is listening, if using a remote ADB + connection. + + :return: + - An integer port number (e.g. 5037). 
+ - ``None`` if not applicable or unknown. + """ return getattr(self.conn, 'adb_port', None) @property @memoized - def android_id(self): + def android_id(self) -> str: """ Get the device's ANDROID_ID. Which is @@ -1986,30 +3068,69 @@ def android_id(self): .. note:: This will get reset on userdata erasure. + :return: The ANDROID_ID in hexadecimal form. + """ + # FIXME - would it be better to just do 'settings get secure android_id' ? when trying to execute the content command, + # getting some access issues with settings output = self.execute('content query --uri content://settings/secure --projection value --where "name=\'android_id\'"').strip() return output.split('value=')[-1] @property @memoized - def system_id(self): + def system_id(self) -> str: + """ + Obtain a unique Android system identifier by using a device utility + (e.g., 'get_android_system_id' in shutils). + + :return: A device-specific ID string. + """ return self._execute_util('get_android_system_id').strip() @property @memoized - def external_storage(self): + def external_storage(self) -> str: + """ + The path to the device's external storage directory (often ``/sdcard`` or + ``/storage/emulated/0``). + + :return: + A filesystem path pointing to the shared/SD card area on the Android device. + :raises TargetStableError: + If the environment variable ``EXTERNAL_STORAGE`` is unset or an error + occurs reading it. + """ return self.execute('echo $EXTERNAL_STORAGE').strip() @property @memoized - def external_storage_app_dir(self): - return self.path.join(self.external_storage, 'Android', 'data') + def external_storage_app_dir(self) -> Optional[str]: + """ + The application-specific directory within external storage + (commonly ``/sdcard/Android/data``). + + :return: + The path to the app-specific directory under external storage, or + ``None`` if not determinable (e.g. no external storage). 
+ """ + if self.path: + return self.path.join(self.external_storage, 'Android', 'data') + return None @property @memoized - def screen_resolution(self): - output = self.execute('dumpsys window displays') - match = ANDROID_SCREEN_RESOLUTION_REGEX.search(output) + def screen_resolution(self) -> Tuple[int, int]: + """ + The current display resolution (width, height), read from ``dumpsys window displays``. + + :return: + A tuple ``(width, height)`` of the device’s screen resolution in pixels. + + :raises TargetStableError: + If the resolution cannot be parsed from ``dumpsys`` output. + """ + output: str = self.execute('dumpsys window displays') + match: Optional[Match[str]] = ANDROID_SCREEN_RESOLUTION_REGEX.search(output) if match: return (int(match.group('width')), int(match.group('height'))) @@ -2017,20 +3138,24 @@ def screen_resolution(self): return (0, 0) def __init__(self, - connection_settings=None, - platform=None, - working_directory=None, - executables_directory=None, - connect=True, - modules=None, - load_default_modules=True, - shell_prompt=DEFAULT_SHELL_PROMPT, - conn_cls=AdbConnection, - package_data_directory="/data/data", - is_container=False, - max_async=50, - tmp_directory=None, + connection_settings: Optional[UserConnectionSettings] = None, + platform: Optional[Platform] = None, + working_directory: Optional[str] = None, + executables_directory: Optional[str] = None, + connect: bool = True, + modules: Optional[Dict[str, Dict[str, Type[Module]]]] = None, + load_default_modules: bool = True, + shell_prompt: Pattern[str] = DEFAULT_SHELL_PROMPT, + conn_cls: 'InitCheckpointMeta' = AdbConnection, + package_data_directory: str = "/data/data", + is_container: bool = False, + max_async: int = 50, + tmp_directory: Optional[str] = None, ): + """ + Initialize an AndroidTarget instance and optionally connect to the + device via ADB. 
+ """ super(AndroidTarget, self).__init__(connection_settings=connection_settings, platform=platform, working_directory=working_directory, @@ -2047,10 +3172,17 @@ def __init__(self, self.package_data_directory = package_data_directory self._init_logcat_lock() - def _init_logcat_lock(self): + def _init_logcat_lock(self) -> None: + """ + Initialize a lock used for serializing logcat clearing operations. + This prevents overlapping ``logcat -c`` calls from multiple threads. + """ self.clear_logcat_lock = threading.Lock() - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: + """ + Extend the base pickling to skip the `clear_logcat_lock`. + """ dct = super().__getstate__() return { k: v @@ -2058,35 +3190,64 @@ def __getstate__(self): if k not in ('clear_logcat_lock',) } - def __setstate__(self, dct): + def __setstate__(self, dct: Dict[str, Any]) -> None: + """ + Restore post-pickle state, reinitializing the logcat lock. + """ super().__setstate__(dct) self._init_logcat_lock() @asyn.asyncf async def reset(self, fastboot=False): # pylint: disable=arguments-differ + """ + Soft reset (reboot) the device. If ``fastboot=True``, attempt to reboot + into fastboot mode. + + :param fastboot: If True, reboot into fastboot instead of normal reboot. + :raises DevlibTransientError: If "reboot" command fails or times out. 
+ """ try: await self.execute.asyn('reboot {}'.format(fastboot and 'fastboot' or ''), - as_root=self.needs_su, timeout=2) + as_root=self.needs_su, timeout=2) except (DevlibTransientError, subprocess.CalledProcessError): # on some targets "reboot" doesn't return gracefully pass self.conn.connected_as_root = None @asyn.asyncf - async def wait_boot_complete(self, timeout=10): - start = time.time() - boot_completed = boolean(await self.getprop.asyn('sys.boot_completed')) - while not boot_completed and timeout >= time.time() - start: - time.sleep(5) - boot_completed = boolean(await self.getprop.asyn('sys.boot_completed')) - if not boot_completed: - # Raise a TargetStableError as this usually happens because of - # an issue with Android more than a timeout that is too small. - raise TargetStableError('Connected but Android did not fully boot.') + async def wait_boot_complete(self, timeout: Optional[int] = 10) -> None: + """ + Wait for Android to finish booting, typically by polling ``sys.boot_completed`` + property. + + :param timeout: Seconds to wait. If the property isn't set by this time, raise. + :raises TargetStableError: If the device remains un-booted after `timeout` seconds. + """ + start: float = time.time() + boot_completed: bool = boolean(await self.getprop.asyn('sys.boot_completed')) + if timeout: + while not boot_completed and timeout >= time.time() - start: + time.sleep(5) + boot_completed = boolean(await self.getprop.asyn('sys.boot_completed')) + if not boot_completed: + # Raise a TargetStableError as this usually happens because of + # an issue with Android more than a timeout that is too small. 
+ raise TargetStableError('Connected but Android did not fully boot.') @asyn.asyncf - async def connect(self, timeout=30, check_boot_completed=True, max_async=None): # pylint: disable=arguments-differ - device = self.connection_settings.get('device') + async def connect(self, timeout: Optional[int] = 30, + check_boot_completed: Optional[bool] = True, + max_async: Optional[int] = None) -> None: # pylint: disable=arguments-differ + """ + Establish a connection to the target. It is usually not necessary to call + this explicitly, as a connection gets automatically established on + instantiation. + + :param timeout: Time in seconds before giving up on connection attempts. + :param check_boot_completed: Whether to call :meth:`wait_boot_complete`. + :param max_async: Override the default concurrency limit if provided. + :raises TargetError: If the device fails to connect. + """ await super(AndroidTarget, self).connect.asyn( timeout=timeout, check_boot_completed=check_boot_completed, @@ -2094,7 +3255,12 @@ async def connect(self, timeout=30, check_boot_completed=True, max_async=None): ) @asyn.asyncf - async def __setup_list_directory(self): + async def __setup_list_directory(self) -> None: + """ + One-time setup to determine if the device supports ``ls -1``. On older + Android versions, the ``-1`` flag might not be available, so fallback + to plain ``ls``. + """ # In at least Linaro Android 16.09 (which was their first Android 7 release) and maybe # AOSP 7.0 as well, the ls command was changed. # Previous versions default to a single column listing, which is nice and easy to parse. @@ -2103,34 +3269,66 @@ async def __setup_list_directory(self): # so we try the new version, and if it fails we use the old version. 
self.ls_command = 'ls -1' try: - await self.execute.asyn('ls -1 {}'.format(quote(self.working_directory)), as_root=False) + await self.execute.asyn('ls -1 {}'.format(quote(self.working_directory or '')), as_root=False) except TargetStableError: self.ls_command = 'ls' - async def _list_directory(self, path, as_root=False): + async def _list_directory(self, path: str, as_root: bool = False) -> List[str]: + """ + Implementation of :meth:`list_directory` for Android. Uses an ls command + that might be adjusted depending on OS version. + + :param path: Directory path on the device. + :param as_root: If True, escalate privileges for listing. + :return: A list of file/directory names in the specified path. + :raises TargetStableError: If the directory doesn't exist or can't be listed. + """ if self.ls_command == '': await self.__setup_list_directory.asyn() contents = await self.execute.asyn('{} {}'.format(self.ls_command, quote(path)), as_root=as_root) return [x.strip() for x in contents.split('\n') if x.strip()] @asyn.asyncf - async def install(self, filepath, timeout=None, with_name=None): # pylint: disable=W0221 - ext = os.path.splitext(filepath)[1].lower() + async def install(self, filepath: str, timeout: Optional[int] = None, + with_name: Optional[str] = None) -> str: # pylint: disable=W0221 + """ + Install a file (APK or binary) onto the Android device. If the file is an APK, + use :meth:`install_apk`; otherwise, use :meth:`install_executable`. + + :param filepath: Path on the host to the file (APK or binary). + :param timeout: Optional time in seconds to allow the install. + :param with_name: If installing a binary, rename it on the device. Ignored for APKs. + :return: The path or package installed on the device. + :raises TargetStableError: If the file extension is unsupported or installation fails. 
+ """ + ext: str = os.path.splitext(filepath)[1].lower() if ext == '.apk': return await self.install_apk.asyn(filepath, timeout) else: return await self.install_executable.asyn(filepath, with_name, timeout) @asyn.asyncf - async def uninstall(self, name): + async def uninstall(self, name: str) -> None: + """ + Uninstall either a package (if installed as an APK) or an executable from + the device. + + :param name: The package name or binary name to remove. + """ if await self.package_is_installed.asyn(name): await self.uninstall_package.asyn(name) else: await self.uninstall_executable.asyn(name) @asyn.asyncf - async def get_pids_of(self, process_name): - result = [] + async def get_pids_of(self, process_name: str) -> List[int]: + """ + Return a list of process IDs (PIDs) for any processes matching ``process_name``. + + :param process_name: The substring or name to search for in the command name. + :return: List of integer PIDs matching the name. + """ + result: List[int] = [] search_term = process_name[-15:] for entry in await self.ps.asyn(): if search_term in entry.name: @@ -2138,7 +3336,17 @@ async def get_pids_of(self, process_name): return result @asyn.asyncf - async def ps(self, threads=False, **kwargs): + async def ps(self, threads: bool = False, **kwargs: Dict[str, Any]) -> List['PsEntry']: + """ + Return a list of process entries on the device (like ``ps`` output), + optionally including thread info if ``threads=True``. + + :param threads: If True, use ``ps -AT`` to include threads. + :param kwargs: Key/value filters to match against the returned attributes + (like user, name, etc.). + :return: A list of PsEntry objects matching the filter. + :raises TargetStableError: If the command fails or ps output is malformed. 
+ """ maxsplit = 9 if threads else 8 command = 'ps' if threads: @@ -2146,13 +3354,13 @@ async def ps(self, threads=False, **kwargs): lines = iter(convert_new_lines(await self.execute.asyn(command)).split('\n')) next(lines) # header - result = [] + result: List['PsEntry'] = [] for line in lines: - parts = line.split(None, maxsplit) + parts: List[str] = line.split(None, maxsplit) if not parts: continue - wchan_missing = False + wchan_missing: bool = False if len(parts) == maxsplit: wchan_missing = True @@ -2167,30 +3375,56 @@ async def ps(self, threads=False, **kwargs): if not kwargs: return result else: - filtered_result = [] + filtered_result: List['PsEntry'] = [] for entry in result: if all(getattr(entry, k) == v for k, v in kwargs.items()): filtered_result.append(entry) return filtered_result @asyn.asyncf - async def capture_screen(self, filepath): - on_device_file = self.path.join(self.working_directory, 'screen_capture.png') - cmd = 'screencap -p {} && {} date -u -Iseconds' - ts = (await self.execute.asyn(cmd.format(quote(on_device_file), quote(self.busybox)))).strip() - filepath = filepath.format(ts=ts) - await self.pull.asyn(on_device_file, filepath) - await self.remove.asyn(on_device_file) + async def capture_screen(self, filepath: str) -> None: + """ + Take a screenshot on the device and save it to the specified file on the + host. This may not be supported by the target. You can optionally insert a + ``{ts}`` tag into the file name, in which case it will be substituted with + on-target timestamp of the screen shot in ISO8601 format. + + :param filepath: The host file path to store the screenshot. E.g. + ``"my_screenshot_{ts}.png"`` + :raises TargetStableError: If the device lacks a necessary screenshot tool (e.g. screencap). 
+ """ + if self.path and self.working_directory: + on_device_file: str = self.path.join(self.working_directory, 'screen_capture.png') + cmd = 'screencap -p {} && {} date -u -Iseconds' + if self.busybox: + ts = (await self.execute.asyn(cmd.format(quote(on_device_file), quote(self.busybox)))).strip() + filepath = filepath.format(ts=ts) + await self.pull.asyn(on_device_file, filepath) + await self.remove.asyn(on_device_file) # Android-specific @asyn.asyncf - async def input_tap(self, x, y): + async def input_tap(self, x: int, y: int) -> None: + """ + Simulate a tap/click event at (x, y) on the device screen. + + :param x: The horizontal coordinate (pixels). + :param y: The vertical coordinate (pixels). + :raises TargetStableError: If the ``input`` command is not found or fails. + """ command = 'input tap {} {}' await self.execute.asyn(command.format(x, y)) @asyn.asyncf - async def input_tap_pct(self, x, y): + async def input_tap_pct(self, x: int, y: int): + """ + Simulate a tap event using percentage-based coordinates, relative + to the device screen size. + + :param x: Horizontal position as a percentage of screen width (0 to 100). + :param y: Vertical position as a percentage of screen height (0 to 100). + """ width, height = self.screen_resolution x = (x * width) // 100 @@ -2199,19 +3433,27 @@ async def input_tap_pct(self, x, y): await self.input_tap.asyn(x, y) @asyn.asyncf - async def input_swipe(self, x1, y1, x2, y2): + async def input_swipe(self, x1: int, y1: int, x2: int, y2: int) -> None: """ - Issue a swipe on the screen from (x1, y1) to (x2, y2) - Uses absolute screen positions + Issue a swipe gesture from (x1, y1) to (x2, y2), using absolute pixel coordinates. + + :param x1: Start X coordinate in pixels. + :param y1: Start Y coordinate in pixels. + :param x2: End X coordinate in pixels. + :param y2: End Y coordinate in pixels. 
""" command = 'input swipe {} {} {} {}' await self.execute.asyn(command.format(x1, y1, x2, y2)) @asyn.asyncf - async def input_swipe_pct(self, x1, y1, x2, y2): + async def input_swipe_pct(self, x1: int, y1: int, x2: int, y2: int) -> None: """ - Issue a swipe on the screen from (x1, y1) to (x2, y2) - Uses percent-based positions + Issue a swipe gesture from (x1, y1) to (x2, y2) using percentage-based coordinates. + + :param x1: Horizontal start percentage (0-100). + :param y1: Vertical start percentage (0-100). + :param x2: Horizontal end percentage (0-100). + :param y2: Vertical end percentage (0-100). """ width, height = self.screen_resolution @@ -2223,7 +3465,14 @@ async def input_swipe_pct(self, x1, y1, x2, y2): await self.input_swipe.asyn(x1, y1, x2, y2) @asyn.asyncf - async def swipe_to_unlock(self, direction="diagonal"): + async def swipe_to_unlock(self, direction: str = "diagonal") -> None: + """ + Attempt to swipe the lock screen open. Common directions are ``"horizontal"``, + ``"vertical"``, or ``"diagonal"``. + + :param direction: The direction to swipe; defaults to diagonal for maximum coverage. + :raises TargetStableError: If the direction is invalid or the swipe fails. + """ width, height = self.screen_resolution if direction == "diagonal": start = 100 @@ -2236,21 +3485,38 @@ async def swipe_to_unlock(self, direction="diagonal"): stop = width - start await self.input_swipe.asyn(start, swipe_height, stop, swipe_height) elif direction == "vertical": - swipe_middle = width / 2 + swipe_middle = width // 2 swipe_height = height * 2 // 3 await self.input_swipe.asyn(swipe_middle, swipe_height, swipe_middle, 0) else: raise TargetStableError("Invalid swipe direction: {}".format(direction)) @asyn.asyncf - async def getprop(self, prop=None): + async def getprop(self, prop: Optional[str] = None) -> Optional[Union[str, AndroidProperties]]: + """ + Fetch properties from Android's ``getprop``. 
If ``prop`` is given, + return just that property's value; otherwise return a dictionary-like + :class:`AndroidProperties`. + + :param prop: A specific property key to retrieve (e.g. "ro.build.version.sdk"). + :return: + - If ``prop`` is None, a dictionary-like object mapping all property keys to values. + - If ``prop`` is non-empty, the string value of that specific property. + """ props = AndroidProperties(await self.execute.asyn('getprop')) if prop: return props[prop] return props @asyn.asyncf - async def capture_ui_hierarchy(self, filepath): + async def capture_ui_hierarchy(self, filepath: str) -> None: + """ + Capture the current UI hierarchy via ``uiautomator dump``, pull it to + the host, and optionally format it with pretty XML. + + :param filepath: The host file path to save the UI hierarchy XML. + :raises TargetStableError: If the device cannot produce a dump or fails to store it. + """ on_target_file = self.get_workpath('screen_capture.xml') try: await self.execute.asyn('uiautomator dump {}'.format(on_target_file)) @@ -2258,26 +3524,47 @@ async def capture_ui_hierarchy(self, filepath): finally: await self.remove.asyn(on_target_file) - parsed_xml = xml.dom.minidom.parse(filepath) + parsed_xml: Document = xml.dom.minidom.parse(filepath) with open(filepath, 'w') as f: f.write(parsed_xml.toprettyxml()) @asyn.asyncf - async def is_installed(self, name): + async def is_installed(self, name: str) -> bool: + """ + Returns ``True`` if an executable with the specified name is installed on the + target and ``False`` other wise. + """ return (await super(AndroidTarget, self).is_installed.asyn(name)) or (await self.package_is_installed.asyn(name)) @asyn.asyncf - async def package_is_installed(self, package_name): + async def package_is_installed(self, package_name: str) -> bool: + """ + Check if the given package name is installed on the device. + + :param package_name: Name of the Android package (e.g. "com.example.myapp"). 
+ :return: True if installed, False otherwise. + """ return package_name in (await self.list_packages.asyn()) @asyn.asyncf - async def list_packages(self): - output = await self.execute.asyn('pm list packages') + async def list_packages(self) -> List[str]: + """ + Return a list of installed package names on the device (via ``pm list packages``). + + :return: A list of package identifiers. + """ + output: str = await self.execute.asyn('pm list packages') output = output.replace('package:', '') return output.split() @asyn.asyncf - async def get_package_version(self, package): + async def get_package_version(self, package: str) -> Optional[str]: + """ + Obtain the versionName for a given package by parsing ``dumpsys package``. + + :param package: The package name (e.g. "com.example.myapp"). + :return: The versionName string if found, otherwise None. + """ output = await self.execute.asyn('dumpsys package {}'.format(quote(package))) for line in convert_new_lines(output).split('\n'): if 'versionName' in line: @@ -2285,27 +3572,51 @@ async def get_package_version(self, package): return None @asyn.asyncf - async def get_package_info(self, package): - output = await self.execute.asyn('pm list packages -f {}'.format(quote(package))) + async def get_package_info(self, package: str) -> Optional['installed_package_info']: + """ + Return a tuple (apk_path, package_name) for the installed package, or None if not found. + + :param package: The package identifier (e.g. "com.example.myapp"). + :return: A namedtuple with fields (apk_path, package), or None. 
+ """ + output: str = await self.execute.asyn('pm list packages -f {}'.format(quote(package))) for entry in output.strip().split('\n'): rest, entry_package = entry.rsplit('=', 1) if entry_package != package: continue _, apk_path = rest.split(':') return installed_package_info(apk_path, entry_package) + return None @asyn.asyncf - async def get_sdk_version(self): + async def get_sdk_version(self) -> Optional[int]: + """ + Return the integer value of ``ro.build.version.sdk`` if parseable; None if not. + + :return: e.g. 29 for Android 10, or None on error. + """ try: return int(await self.getprop.asyn('ro.build.version.sdk')) except (ValueError, TypeError): return None @asyn.asyncf - async def install_apk(self, filepath, timeout=None, replace=False, allow_downgrade=False): # pylint: disable=W0221 - ext = os.path.splitext(filepath)[1].lower() + async def install_apk(self, filepath: str, timeout: Optional[int] = None, replace: Optional[bool] = False, + allow_downgrade: Optional[bool] = False) -> Optional[str]: # pylint: disable=W0221 + """ + Install an APK onto the device. If the device is connected via AdbConnection, + use an ADB install command. Otherwise, push it and run 'pm install'. + + :param filepath: The path to the APK on the host. + :param timeout: The time in seconds to wait for installation. + :param replace: If True, pass -r to 'pm install' or `adb install`. + :param allow_downgrade: If True, allow installing an older version over a newer one. + :return: The output from the install command, or None if something unexpected occurs. + :raises TargetStableError: If the file is not an APK or installation fails. 
+ """ + ext: str = os.path.splitext(filepath)[1].lower() if ext == '.apk': - flags = [] + flags: List[str] = [] if replace: flags.append('-r') # Replace existing APK if allow_downgrade: @@ -2319,80 +3630,121 @@ async def install_apk(self, filepath, timeout=None, replace=False, allow_downgra timeout=timeout, adb_server=self.adb_server, adb_port=self.adb_port) else: - dev_path = self.get_workpath(filepath.rsplit(os.path.sep, 1)[-1]) + dev_path: Optional[str] = self.get_workpath(filepath.rsplit(os.path.sep, 1)[-1]) await self.push.asyn(quote(filepath), dev_path, timeout=timeout) - result = await self.execute.asyn("pm install {} {}".format(' '.join(flags), quote(dev_path)), timeout=timeout) - await self.remove.asyn(dev_path) - return result + if dev_path: + result: str = await self.execute.asyn("pm install {} {}".format(' '.join(flags), quote(dev_path)), timeout=timeout) + await self.remove.asyn(dev_path) + return result + else: + raise TargetStableError('Can\'t install. could not get dev path') else: raise TargetStableError('Can\'t install {}: unsupported format.'.format(filepath)) @asyn.asyncf - async def grant_package_permission(self, package, permission): + async def grant_package_permission(self, package: str, permission: str) -> None: + """ + Run `pm grant `. Ignores some errors if the permission + cannot be granted. This is typically used for runtime permissions on modern Android. + + :param package: The target package. + :param permission: The permission string to grant (e.g. "android.permission.READ_LOGS"). + :raises TargetStableError: If some unexpected error occurs that is not a known ignorable case. 
+ """ try: return await self.execute.asyn('pm grant {} {}'.format(quote(package), quote(permission))) except TargetStableError as e: - if 'is not a changeable permission type' in e.message: - pass # Ignore if unchangeable - elif 'Unknown permission' in e.message: - pass # Ignore if unknown - elif 'has not requested permission' in e.message: - pass # Ignore if not requested - elif 'Operation not allowed' in e.message: - pass # Ignore if not allowed - elif 'is managed by role' in e.message: - pass # Ignore if cannot be granted + if isinstance(e.message, str): + if 'is not a changeable permission type' in e.message: + pass # Ignore if unchangeable + elif 'Unknown permission' in e.message: + pass # Ignore if unknown + elif 'has not requested permission' in e.message: + pass # Ignore if not requested + elif 'Operation not allowed' in e.message: + pass # Ignore if not allowed + elif 'is managed by role' in e.message: + pass # Ignore if cannot be granted + else: + raise else: raise @asyn.asyncf - async def refresh_files(self, file_list): + async def refresh_files(self, file_list: List[str]) -> None: """ - Depending on the android version and root status, determine the - appropriate method of forcing a re-index of the mediaserver cache for a given - list of files. + Attempt to force a re-index of the device media scanner for the given files. + On newer Android (7+), if not rooted, we fallback to scanning each file individually. + + :param file_list: A list of file paths on the device that may need indexing (e.g. new media). 
""" - if self.is_rooted or (await self.get_sdk_version.asyn()) < 24: # MM and below - common_path = commonprefix(file_list, sep=self.path.sep) + if self.path and (self.is_rooted or (await self.get_sdk_version.asyn()) < 24): # MM and below + common_path: str = commonprefix(file_list, sep=self.path.sep) await self.broadcast_media_mounted.asyn(common_path, self.is_rooted) else: for f in file_list: await self.broadcast_media_scan_file.asyn(f) @asyn.asyncf - async def broadcast_media_scan_file(self, filepath): + async def broadcast_media_scan_file(self, filepath: str) -> None: """ - Force a re-index of the mediaserver cache for the specified file. + Send a broadcast intent to the Android media scanner for a single file path. + + :param filepath: File path on the device to be scanned by mediaserver. """ command = 'am broadcast -a android.intent.action.MEDIA_SCANNER_SCAN_FILE -d {}' await self.execute.asyn(command.format(quote('file://' + filepath))) @asyn.asyncf - async def broadcast_media_mounted(self, dirpath, as_root=False): + async def broadcast_media_mounted(self, dirpath: str, as_root: bool = False) -> None: """ - Force a re-index of the mediaserver cache for the specified directory. + Broadcast that media at a directory path is newly mounted, prompting scanning + of its contents. + + :param dirpath: Directory path on the device. + :param as_root: If True, escalate privileges for the broadcast command. 
""" command = 'am broadcast -a android.intent.action.MEDIA_MOUNTED -d {} '\ '-n com.android.providers.media/.MediaScannerReceiver' - await self.execute.asyn(command.format(quote('file://'+dirpath)), as_root=as_root) + await self.execute.asyn(command.format(quote('file://' + dirpath)), as_root=as_root) @asyn.asyncf - async def install_executable(self, filepath, with_name=None, timeout=None): + async def install_executable(self, filepath: str, with_name: Optional[str] = None, + timeout: Optional[int] = None) -> Optional[str]: + """ + Install a single executable (non-APK) onto the device. Typically places + it in :attr:`executables_directory`, making it executable with chmod. + + :param filepath: The path on the host to the binary. + :param with_name: Optional name to rename the binary on the device. + :param timeout: Time in seconds to allow the push & setup. + :return: Path to the installed binary on the device, or None on failure. + :raises TargetStableError: If the push or setup steps fail. 
+ """ self._ensure_executables_directory_is_writable() - executable_name = with_name or os.path.basename(filepath) - on_device_file = self.path.join(self.working_directory, executable_name) - on_device_executable = self.path.join(self.executables_directory, executable_name) - await self.push.asyn(filepath, on_device_file, timeout=timeout) - if on_device_file != on_device_executable: - await self.execute.asyn('cp -f -- {} {}'.format(quote(on_device_file), quote(on_device_executable)), - as_root=self.needs_su, timeout=timeout) - await self.remove.asyn(on_device_file, as_root=self.needs_su) - await self.execute.asyn("chmod 0777 {}".format(quote(on_device_executable)), as_root=self.needs_su) - self._installed_binaries[executable_name] = on_device_executable - return on_device_executable - - @asyn.asyncf - async def uninstall_package(self, package): + executable_name: str = with_name or os.path.basename(filepath) + if self.path: + on_device_file: str = self.path.join(self.working_directory or '', executable_name) + on_device_executable: str = self.path.join(self.executables_directory or '', executable_name) + await self.push.asyn(filepath, on_device_file, timeout=timeout) + if on_device_file != on_device_executable: + await self.execute.asyn('cp -f -- {} {}'.format(quote(on_device_file), quote(on_device_executable)), + as_root=self.needs_su, timeout=timeout) + await self.remove.asyn(on_device_file, as_root=self.needs_su) + await self.execute.asyn("chmod 0777 {}".format(quote(on_device_executable)), as_root=self.needs_su) + self._installed_binaries[executable_name] = on_device_executable + return on_device_executable + else: + raise TargetStableError('path is not assigned') + + @asyn.asyncf + async def uninstall_package(self, package: str) -> None: + """ + Uninstall an Android package by name (using ``adb uninstall`` or + ``pm uninstall``). + + :param package: The package name to remove. 
+ """ if isinstance(self.conn, AdbConnection): adb_command(self.adb_name, "uninstall {}".format(quote(package)), timeout=30, adb_server=self.adb_server, adb_port=self.adb_port) @@ -2400,14 +3752,30 @@ async def uninstall_package(self, package): await self.execute.asyn("pm uninstall {}".format(quote(package)), timeout=30) @asyn.asyncf - async def uninstall_executable(self, executable_name): - on_device_executable = self.path.join(self.executables_directory, executable_name) - self._ensure_executables_directory_is_writable() - await self.remove.asyn(on_device_executable, as_root=self.needs_su) + async def uninstall_executable(self, executable_name: str) -> None: + """ + Remove an installed executable from :attr:`executables_directory`. + + :param executable_name: The name of the binary to remove. + """ + if self.path: + on_device_executable = self.path.join(self.executables_directory or '', executable_name) + self._ensure_executables_directory_is_writable() + await self.remove.asyn(on_device_executable, as_root=self.needs_su) @asyn.asyncf - async def dump_logcat(self, filepath, filter=None, logcat_format=None, append=False, - timeout=60): # pylint: disable=redefined-builtin + async def dump_logcat(self, filepath: str, filter: Optional[str] = None, + logcat_format: Optional[str] = None, + append: bool = False, timeout: int = 60) -> None: # pylint: disable=redefined-builtin + """ + Collect logcat output from the device and save it to ``filepath`` on the host. + + :param filepath: The file on the host to store the log output. + :param filter: If provided, a filter specifying which tags to match (e.g. '-s MyTag'). + :param logcat_format: Logcat format (e.g., 'threadtime'), if any. + :param append: If True, append to the host file instead of overwriting. + :param timeout: How many seconds to allow for reading the log. 
+ """ op = '>>' if append else '>' filtstr = ' -s {}'.format(quote(filter)) if filter else '' formatstr = ' -v {}'.format(quote(logcat_format)) if logcat_format else '' @@ -2418,13 +3786,17 @@ async def dump_logcat(self, filepath, filter=None, logcat_format=None, append=Fa adb_port=self.adb_port) else: dev_path = self.get_workpath('logcat') - command = 'logcat {} {} {}'.format(logcat_opts, op, quote(dev_path)) - await self.execute.asyn(command, timeout=timeout) - await self.pull.asyn(dev_path, filepath) - await self.remove.asyn(dev_path) + if dev_path: + command = 'logcat {} {} {}'.format(logcat_opts, op, quote(dev_path)) + await self.execute.asyn(command, timeout=timeout) + await self.pull.asyn(dev_path, filepath) + await self.remove.asyn(dev_path) @asyn.asyncf - async def clear_logcat(self): + async def clear_logcat(self) -> None: + """ + Clear the device's logcat (``logcat -c``). Uses a lock to avoid concurrency issues. + """ locked = self.clear_logcat_lock.acquire(blocking=False) if locked: try: @@ -2436,26 +3808,59 @@ async def clear_logcat(self): finally: self.clear_logcat_lock.release() - def get_logcat_monitor(self, regexps=None): + def get_logcat_monitor(self, regexps: Optional[List[str]] = None) -> LogcatMonitor: + """ + Create a :class:`LogcatMonitor` object for capturing logcat output from the device. + + :param regexps: An optional list of uncompiled regex strings to filter log entries. + :return: A new LogcatMonitor instance referencing this AndroidTarget. + """ return LogcatMonitor(self, regexps) @call_conn - def wait_for_device(self, timeout=30): - self.conn.wait_for_device() + def wait_for_device(self, timeout: int = 30) -> None: + """ + Instruct ADB to wait until the device is present (``adb wait-for-device``). + + :param timeout: Seconds to wait before failing. + :raises TargetStableError: If waiting times out or if the connection is not ADB. 
+ """ + if isinstance(self.conn, AdbConnection): + self.conn.wait_for_device() @call_conn - def reboot_bootloader(self, timeout=30): - self.conn.reboot_bootloader() + def reboot_bootloader(self, timeout: int = 30) -> None: + """ + Reboot the device into fastboot/bootloader mode. + + :param timeout: Time in seconds to allow for device to transition. + :raises TargetStableError: If not using ADB or the command fails. + """ + if isinstance(self.conn, AdbConnection): + self.conn.reboot_bootloader() @asyn.asyncf - async def is_screen_locked(self): + async def is_screen_locked(self) -> bool: + """ + Determine if the lock screen is active (e.g., phone is locked). + + :return: True if the screen is locked, False otherwise. + """ screen_state = await self.execute.asyn('dumpsys window') return 'mDreamingLockscreen=true' in screen_state @asyn.asyncf - async def is_screen_on(self): - output = await self.execute.asyn('dumpsys power') - match = ANDROID_SCREEN_STATE_REGEX.search(output) + async def is_screen_on(self) -> bool: + """ + Check if the device screen is currently on. + + :return: + - True if the screen is on or in certain "doze" states. + - False if the screen is off or fully asleep. + :raises TargetStableError: If unable to parse display power state. + """ + output: str = await self.execute.asyn('dumpsys power') + match: Optional[Match[str]] = ANDROID_SCREEN_STATE_REGEX.search(output) if match: if 'DOZE' in match.group(1).upper(): return True @@ -2470,19 +3875,51 @@ async def is_screen_on(self): raise TargetStableError('Could not establish screen state.') @asyn.asyncf - async def ensure_screen_is_on(self, verify=True): + async def ensure_screen_is_on(self, verify: bool = True) -> None: + """ + If the screen is off, press the power button (keyevent 26) to wake it. + Optionally verify the screen is on afterwards. + + :param verify: If True, raise an error if the screen doesn't turn on. + :raises TargetStableError: If the screen is still off after the attempt. 
+ """ if not await self.is_screen_on.asyn(): + # The adb shell input keyevent 26 command is used to + # simulate pressing the power button on an Android device. self.execute('input keyevent 26') if verify and not await self.is_screen_on.asyn(): raise TargetStableError('Display cannot be turned on.') @asyn.asyncf - async def ensure_screen_is_on_and_stays(self, verify=True, mode=7): + async def ensure_screen_is_on_and_stays(self, verify: bool = True, mode: int = 7) -> None: + """ + Calls ``AndroidTarget.ensure_screen_is_on(verify)`` then additionally + sets the screen stay on mode to ``mode``. + mode options - + 0: Never stay on while plugged in. + 1: Stay on while plugged into an AC charger. + 2: Stay on while plugged into a USB charger. + 4: Stay on while on a wireless charger. + You can combine these values using bitwise OR. + For example, 3 (1 | 2) will stay on while plugged into either an AC or USB charger + + :param verify: If True, check that the screen does come on. + :param mode: A bitwise combination of (1 for AC, 2 for USB, 4 for wireless). + """ await self.ensure_screen_is_on.asyn(verify=verify) await self.set_stay_on_mode.asyn(mode) @asyn.asyncf - async def ensure_screen_is_off(self, verify=True): + async def ensure_screen_is_off(self, verify: bool = True) -> None: + """ + Checks if the devices screen is on and if so turns it off. + If ``verify`` is set to ``True`` then a ``TargetStableError`` + will be raise if the display cannot be turned off. E.g. if + always on mode is enabled. + + :param verify: Raise an error if the screen remains on afterwards. + :raises TargetStableError: If the display remains on due to always-on or lock states. + """ # Allow 2 attempts to help with cases of ambient display modes # where the first attempt will switch the display fully on. 
for _ in range(2): @@ -2490,21 +3927,38 @@ async def ensure_screen_is_off(self, verify=True): await self.execute.asyn('input keyevent 26') time.sleep(0.5) if verify and await self.is_screen_on.asyn(): - msg = 'Display cannot be turned off. Is always on display enabled?' - raise TargetStableError(msg) + msg: str = 'Display cannot be turned off. Is always on display enabled?' + raise TargetStableError(msg) @asyn.asyncf - async def set_auto_brightness(self, auto_brightness): + async def set_auto_brightness(self, auto_brightness: bool) -> None: + """ + Enable or disable automatic screen brightness. + + :param auto_brightness: True to enable auto-brightness, False to disable. + """ cmd = 'settings put system screen_brightness_mode {}' await self.execute.asyn(cmd.format(int(boolean(auto_brightness)))) @asyn.asyncf - async def get_auto_brightness(self): + async def get_auto_brightness(self) -> bool: + """ + Check if auto-brightness is enabled. + + :return: True if auto-brightness is on, False otherwise. + """ cmd = 'settings get system screen_brightness_mode' return boolean((await self.execute.asyn(cmd)).strip()) @asyn.asyncf - async def set_brightness(self, value): + async def set_brightness(self, value: int) -> None: + """ + Manually set screen brightness to an integer between 0 and 255. + This also disables auto-brightness first. + + :param value: Desired brightness level (0-255). + :raises ValueError: If the given value is outside [0..255]. + """ if not 0 <= value <= 255: msg = 'Invalid brightness "{}"; Must be between 0 and 255' raise ValueError(msg.format(value)) @@ -2513,69 +3967,139 @@ async def set_brightness(self, value): await self.execute.asyn(cmd.format(int(value))) @asyn.asyncf - async def get_brightness(self): + async def get_brightness(self) -> int: + """ + Return the current screen brightness (0..255). + + :return: The brightness setting. 
+ """ cmd = 'settings get system screen_brightness' return integer((await self.execute.asyn(cmd)).strip()) @asyn.asyncf - async def set_screen_timeout(self, timeout_ms): + async def set_screen_timeout(self, timeout_ms: int) -> None: + """ + Set the screen-off timeout in milliseconds. + + :param timeout_ms: Number of ms before the screen turns off when idle. + """ cmd = 'settings put system screen_off_timeout {}' await self.execute.asyn(cmd.format(int(timeout_ms))) @asyn.asyncf - async def get_screen_timeout(self): + async def get_screen_timeout(self) -> int: + """ + Get the screen-off timeout (ms). + + :return: Milliseconds before screen turns off. + """ cmd = 'settings get system screen_off_timeout' return int((await self.execute.asyn(cmd)).strip()) @asyn.asyncf - async def get_airplane_mode(self): + async def get_airplane_mode(self) -> bool: + """ + Check if airplane mode is active (global setting). + + .. note:: Requires the device to be rooted if the device is running Android 7+. + + :return: True if airplane mode is on, otherwise False. + """ cmd = 'settings get global airplane_mode_on' return boolean((await self.execute.asyn(cmd)).strip()) @asyn.asyncf - async def get_stay_on_mode(self): + async def get_stay_on_mode(self) -> int: + """ + Returns an integer between ``0`` and ``7`` representing the current + stay-on mode of the device. + 0: Never stay on while plugged in. + 1: Stay on while plugged into an AC charger. + 2: Stay on while plugged into a USB charger. + 4: Stay on while on a wireless charger. + Combinations of these values can be used (e.g., 3 for both AC and USB chargers) + + :return: The integer bitmask (0..7). + """ cmd = 'settings get global stay_on_while_plugged_in' return int((await self.execute.asyn(cmd)).strip()) @asyn.asyncf - async def set_airplane_mode(self, mode): - root_required = await self.get_sdk_version.asyn() > 23 + async def set_airplane_mode(self, mode: bool) -> None: + """ + Enable or disable airplane mode. 
On Android 7+, requires root. + + :param mode: True to enable airplane mode, False to disable. + :raises TargetStableError: If root is required but the device is not rooted. + """ + root_required: bool = await self.get_sdk_version.asyn() > 23 if root_required and not self.is_rooted: raise TargetStableError('Root is required to toggle airplane mode on Android 7+') - mode = int(boolean(mode)) + modeint = int(boolean(mode)) cmd = 'settings put global airplane_mode_on {}' - await self.execute.asyn(cmd.format(mode)) + await self.execute.asyn(cmd.format(modeint)) await self.execute.asyn('am broadcast -a android.intent.action.AIRPLANE_MODE ' - '--ez state {}'.format(mode), as_root=root_required) + '--ez state {}'.format(mode), as_root=root_required) @asyn.asyncf - async def get_auto_rotation(self): + async def get_auto_rotation(self) -> bool: + """ + Check if auto-rotation is enabled (system setting). + + :return: True if accelerometer-based rotation is enabled, False otherwise. + """ cmd = 'settings get system accelerometer_rotation' return boolean((await self.execute.asyn(cmd)).strip()) @asyn.asyncf - async def set_auto_rotation(self, autorotate): + async def set_auto_rotation(self, autorotate: bool) -> None: + """ + Enable or disable auto-rotation of the screen. + + :param autorotate: True to enable, False to disable. + """ cmd = 'settings put system accelerometer_rotation {}' await self.execute.asyn(cmd.format(int(boolean(autorotate)))) @asyn.asyncf - async def set_natural_rotation(self): + async def set_natural_rotation(self) -> None: + """ + Sets the screen orientation of the device to its natural (0 degrees) + orientation. + """ await self.set_rotation.asyn(0) @asyn.asyncf - async def set_left_rotation(self): + async def set_left_rotation(self) -> None: + """ + Sets the screen orientation of the device to 90 degrees. 
+ """ await self.set_rotation.asyn(1) @asyn.asyncf - async def set_inverted_rotation(self): + async def set_inverted_rotation(self) -> None: + """ + Sets the screen orientation of the device to its inverted (180 degrees) + orientation. + """ await self.set_rotation.asyn(2) @asyn.asyncf - async def set_right_rotation(self): + async def set_right_rotation(self) -> None: + """ + Sets the screen orientation of the device to 270 degrees. + """ await self.set_rotation.asyn(3) @asyn.asyncf - async def get_rotation(self): + async def get_rotation(self) -> Optional[int]: + """ + Returns an integer value representing the orientation of the devices + screen. ``0`` : Natural, ``1`` : Rotated Left, ``2`` : Inverted + and ``3`` : Rotated Right. + + :return: The rotation value or None if not found. + """ output = await self.execute.asyn('dumpsys input') match = ANDROID_SCREEN_ROTATION_REGEX.search(output) if match: @@ -2584,7 +4108,15 @@ async def get_rotation(self): return None @asyn.asyncf - async def set_rotation(self, rotation): + async def set_rotation(self, rotation: int) -> None: + """ + Specify an integer representing the desired screen rotation with the + following mappings: Natural: ``0``, Rotated Left: ``1``, Inverted : ``2`` + and Rotated Right : ``3``. + + :param rotation: Integer in [0..3]. + :raises ValueError: If rotation is not within [0..3]. + """ if not 0 <= rotation <= 3: raise ValueError('Rotation value must be between 0 and 3') await self.set_auto_rotation.asyn(False) @@ -2592,70 +4124,95 @@ async def set_rotation(self, rotation): await self.execute.asyn(cmd.format(rotation)) @asyn.asyncf - async def set_stay_on_never(self): + async def set_stay_on_never(self) -> None: + """ + Sets the stay-on mode to ``0``, where the screen will turn off + as standard after the timeout. 
+ """ await self.set_stay_on_mode.asyn(0) @asyn.asyncf - async def set_stay_on_while_powered(self): + async def set_stay_on_while_powered(self) -> None: + """ + Sets the stay-on mode to ``7``, where the screen will stay on + while the device is charging + """ await self.set_stay_on_mode.asyn(7) @asyn.asyncf - async def set_stay_on_mode(self, mode): + async def set_stay_on_mode(self, mode: int) -> None: + """ + 0: Never stay on while plugged in. + 1: Stay on while plugged into an AC charger. + 2: Stay on while plugged into a USB charger. + 4: Stay on while on a wireless charger. + You can combine these values using bitwise OR. + For example, 3 (1 | 2) will stay on while plugged into either an AC or USB charger + + :param mode: Value in [0..7]. + :raises ValueError: If outside [0..7]. + """ if not 0 <= mode <= 7: raise ValueError('Screen stay on mode must be between 0 and 7') cmd = 'settings put global stay_on_while_plugged_in {}' await self.execute.asyn(cmd.format(mode)) @asyn.asyncf - async def open_url(self, url, force_new=False): + async def open_url(self, url: str, force_new: bool = False) -> None: """ - Start a view activity by specifying an URL + Launch an intent to view a given URL, optionally forcing a new task in + the activity stack. - :param url: URL of the item to display - :type url: str - - :param force_new: Force the viewing application to be relaunched - if it is already running - :type force_new: bool + :param url: URL to open (e.g. "https://www.example.com"). + :param force_new: If True, use flags to clear the existing activity stack, + forcing a fresh activity. 
""" cmd = 'am start -a android.intent.action.VIEW -d {}' if force_new: - cmd = cmd + ' -f {}'.format(INTENT_FLAGS['ACTIVITY_NEW_TASK'] | - INTENT_FLAGS['ACTIVITY_CLEAR_TASK']) + cmd = cmd + ' -f {}'.format(INTENT_FLAGS['ACTIVITY_NEW_TASK'] | INTENT_FLAGS['ACTIVITY_CLEAR_TASK']) await self.execute.asyn(cmd.format(quote(url))) @asyn.asyncf - async def homescreen(self): + async def homescreen(self) -> None: + """ + Return to the home screen by launching the MAIN/HOME intent. + """ await self.execute.asyn('am start -a android.intent.action.MAIN -c android.intent.category.HOME') def _resolve_paths(self): if self.working_directory is None: - self.working_directory = self.path.join(self.external_storage, 'devlib-target') + self.working_directory = self.path.join(self.external_storage, 'devlib-target') if self.path else '' if self.tmp_directory is None: # Do not rely on the generic default here, as we need to provide an # android-specific default in case it fails. try: - tmp = self.execute(f'{quote(self.busybox)} mktemp -d') + tmp = self.execute(f'{quote(self.busybox)} mktemp -d') if self.busybox else '/data/local/tmp' except Exception: tmp = '/data/local/tmp' self.tmp_directory = tmp if self.executables_directory is None: - self.executables_directory = self.path.join(self.tmp_directory, 'bin') + self.executables_directory = self.path.join(self.tmp_directory, 'bin') if self.path else '' @asyn.asyncf - async def _ensure_executables_directory_is_writable(self): - matched = [] + async def _ensure_executables_directory_is_writable(self) -> None: + """ + Check if the executables directory is on a writable mount. If not, attempt + to remount it read/write as root. + + :raises TargetStableError: If the directory cannot be remounted or found in fstab. 
+ """ + matched: List['FstabEntry'] = [] for entry in await self.list_file_systems.asyn(): - if self.executables_directory.rstrip('/').startswith(entry.mount_point): + if self.executables_directory is not None and self.executables_directory.rstrip('/').startswith(entry.mount_point): matched.append(entry) if matched: entry = sorted(matched, key=lambda x: len(x.mount_point))[-1] if 'rw' not in entry.options: await self.execute.asyn('mount -o rw,remount {} {}'.format(quote(entry.device), - quote(entry.mount_point)), - as_root=True) + quote(entry.mount_point)), + as_root=True) else: message = 'Could not find mount point for executables directory {}' raise TargetStableError(message.format(self.executables_directory)) @@ -2663,87 +4220,125 @@ async def _ensure_executables_directory_is_writable(self): _charging_enabled_path = '/sys/class/power_supply/battery/charging_enabled' @property - def charging_enabled(self): + def charging_enabled(self) -> Optional[bool]: """ Whether drawing power to charge the battery is enabled Not all devices have the ability to enable/disable battery charging (e.g. because they don't have a battery). In that case, ``charging_enabled`` is None. + + :return: + - True if charging is enabled + - False if disabled + - None if the sysfs entry is absent """ if not self.file_exists(self._charging_enabled_path): return None return self.read_bool(self._charging_enabled_path) @charging_enabled.setter - def charging_enabled(self, enabled): + def charging_enabled(self, enabled: bool) -> None: """ Enable/disable drawing power to charge the battery Not all devices have this facility. In that case, do nothing. + + :param enabled: True to enable charging, False to disable. 
""" if not self.file_exists(self._charging_enabled_path): return self.write_value(self._charging_enabled_path, int(bool(enabled))) + FstabEntry = namedtuple('FstabEntry', ['device', 'mount_point', 'fs_type', 'options', 'dump_freq', 'pass_num']) PsEntry = namedtuple('PsEntry', 'user pid tid ppid vsize rss wchan pc state name') LsmodEntry = namedtuple('LsmodEntry', ['name', 'size', 'use_count', 'used_by']) class Cpuinfo(object): + """ + Represents the parsed contents of ``/proc/cpuinfo`` on the target. + :param sections: A list of dictionaries, where each dictionary represents a + block of lines corresponding to a CPU. Key-value pairs correspond to + lines like ``CPU part: 0xd03`` or ``model name: Cortex-A53``. + :param text: The full text of the original ``/proc/cpuinfo`` content. + """ @property @memoized - def architecture(self): - for section in self.sections: - if 'CPU architecture' in section: - return section['CPU architecture'] - if 'architecture' in section: - return section['architecture'] + def architecture(self) -> Optional[str]: + """ + architecture as per cpuinfo + """ + if self.sections: + for section in self.sections: + if 'CPU architecture' in section: + return section['CPU architecture'] + if 'architecture' in section: + return section['architecture'] + return None @property @memoized - def cpu_names(self): - cpu_names = [] - global_name = None - for section in self.sections: - if 'processor' in section: - if 'CPU part' in section: - cpu_names.append(_get_part_name(section)) - elif 'model name' in section: - cpu_names.append(_get_model_name(section)) - else: - cpu_names.append(None) - elif 'CPU part' in section: - global_name = _get_part_name(section) + def cpu_names(self) -> List[caseless_string]: + """ + A list of CPU names derived from fields like ``CPU part`` or ``model name``. + If found globally, that name is reused for each CPU. If found per-CPU, + you get multiple entries. + + :return: List of CPU names, one per processor entry. 
+ """ + cpu_names: List[Optional[str]] = [] + global_name: Optional[str] = None + if self.sections: + for section in self.sections: + if 'processor' in section: + if 'CPU part' in section: + cpu_names.append(_get_part_name(section)) + elif 'model name' in section: + cpu_names.append(_get_model_name(section)) + else: + cpu_names.append(None) + elif 'CPU part' in section: + global_name = _get_part_name(section) return [caseless_string(c or global_name) for c in cpu_names] - def __init__(self, text): - self.sections = None - self.text = None + def __init__(self, text: str): + self.sections: List[Dict[str, str]] = [] + self.text = '' self.parse(text) @memoized - def get_cpu_features(self, cpuid=0): - global_features = [] - for section in self.sections: - if 'processor' in section: - if int(section.get('processor')) != cpuid: - continue - if 'Features' in section: - return section.get('Features').split() + def get_cpu_features(self, cpuid: int = 0) -> List[str]: + """ + get the Features field of the specified cpu + """ + global_features: List[str] = [] + if self.sections: + for section in self.sections: + if 'processor' in section: + if int(section.get('processor') or -1) != cpuid: + continue + if 'Features' in section: + return section.get('Features', '').split() + elif 'flags' in section: + return section.get('flags', '').split() + elif 'Features' in section: + global_features = section.get('Features', '').split() elif 'flags' in section: - return section.get('flags').split() - elif 'Features' in section: - global_features = section.get('Features').split() - elif 'flags' in section: - global_features = section.get('flags').split() + global_features = section.get('flags', '').split() return global_features - def parse(self, text): + def parse(self, text: str) -> None: + """ + Parse the provided ``/proc/cpuinfo`` text, splitting it into separate + sections for each CPU. + + :param text: The full multiline content of /proc/cpuinfo. 
+ """ self.sections = [] - current_section = {} + current_section: Dict[str, str] = {} self.text = text.strip() for line in self.text.split('\n'): line = line.strip() @@ -2769,37 +4364,26 @@ class KernelVersion(object): :ivar release: Version number/revision string. Typical output of ``uname -r`` - :type release: str :ivar version: Extra version info (aside from ``release``) reported by ``uname`` - :type version: str :ivar version_number: Main version number (e.g. 3 for Linux 3.18) - :type version_number: int :ivar major: Major version number (e.g. 18 for Linux 3.18) - :type major: int :ivar minor: Minor version number for stable kernels (e.g. 9 for 4.9.9). May be None - :type minor: int :ivar rc: Release candidate number (e.g. 3 for Linux 4.9-rc3). May be None. - :type rc: int :ivar commits: Number of additional commits on the branch. May be None. - :type commits: int :ivar sha1: Kernel git revision hash, if available (otherwise None) - :type sha1: str :ivar android_version: Android version, if available (otherwise None) - :type android_version: int :ivar gki_abi: GKI kernel abi, if available (otherwise None) - :type gki_abi: str :ivar parts: Tuple of version number components. Can be used for lexicographically comparing kernel versions. 
- :type parts: tuple(int) """ - def __init__(self, version_string): + def __init__(self, version_string: str): if ' #' in version_string: release, version = version_string.split(' #') - self.release = release - self.version = version + self.release: str = release + self.version: str = version elif version_string.startswith('#'): self.release = '' self.version = version_string @@ -2807,15 +4391,15 @@ def __init__(self, version_string): self.release = version_string self.version = '' - self.version_number = None - self.major = None - self.minor = None - self.sha1 = None - self.rc = None - self.commits = None - self.gki_abi = None - self.android_version = None - match = KVERSION_REGEX.match(version_string) + self.version_number: Optional[int] = None + self.major: Optional[int] = None + self.minor: Optional[int] = None + self.sha1: Optional[str] = None + self.rc: Optional[int] = None + self.commits: Optional[int] = None + self.gki_abi: Optional[str] = None + self.android_version: Optional[int] = None + match: Optional[Match[str]] = KVERSION_REGEX.match(version_string) if match: groups = match.groupdict() self.version_number = int(groups['version']) @@ -2833,7 +4417,7 @@ def __init__(self, version_string): if groups['android_version'] is not None: self.android_version = int(match.group('android_version')) - self.parts = (self.version_number, self.major, self.minor) + self.parts: Tuple[Optional[int], Optional[int], Optional[int]] = (self.version_number, self.major, self.minor) def __str__(self): return '{} {}'.format(self.release, self.version) @@ -2841,68 +4425,132 @@ def __str__(self): __repr__ = __str__ -class HexInt(long): +class HexInt(int): """ - Subclass of :class:`int` that uses hexadecimal formatting by default. + An int subclass that is displayed in hexadecimal form. + + Example usage: + + .. 
code-block:: python + + val = HexInt('FF') # Parse hex string as int + print(val) # Prints: 0xff + print(int(val)) # Prints: 255 """ - def __new__(cls, val=0, base=16): + def __new__(cls, val: Union[str, int, bytearray] = 0, base=16): + """ + Construct a HexInt object, interpreting ``val`` as a base-16 value + unless it's already a number or bytearray. + + :param val: The initial value. If str, is parsed as base-16 by default; + if int or bytearray, used directly. + :param base: Numerical base (defaults to 16). + :raises TypeError: If ``val`` is not a supported type (str, int, or bytearray). + """ super_new = super(HexInt, cls).__new__ if isinstance(val, Number): return super_new(cls, val) + elif isinstance(val, bytearray): + val = int.from_bytes(val, byteorder=sys.byteorder) + return super(HexInt, cls).__new__(cls, val) + elif isinstance(val, str): + return super(HexInt, cls).__new__(cls, int(val, base)) else: - return super_new(cls, val, base=base) + raise TypeError("Unsupported type for HexInt") def __str__(self): + """ + Return a hexadecimal string representation of the integer, stripping + any trailing ``L`` in Python 2.x. + """ return hex(self).strip('L') class KernelConfigTristate(Enum): + """ + Represents a kernel config option that may be ``y``, ``n``, or ``m``. + Commonly seen in kernel ``.config`` files as: + + - ``CONFIG_FOO=y`` + - ``CONFIG_BAR=n`` + - ``CONFIG_BAZ=m`` + + Enum members: + * ``YES`` -> 'y' + * ``NO`` -> 'n' + * ``MODULE`` -> 'm' + """ YES = 'y' NO = 'n' MODULE = 'm' def __bool__(self): """ - Allow using this enum to represent bool Kconfig type, although it is - technically different from tristate. + Allow usage in boolean contexts: + + * True if the config is 'y' or 'm' + * False if the config is 'n' """ return self in (self.YES, self.MODULE) def __nonzero__(self): """ - For Python 2.x compatibility. + Python 2.x compatibility for boolean evaluation. 
""" return self.__bool__() @classmethod - def from_str(cls, str_): + def from_str(cls, str_: str) -> 'KernelConfigTristate': + """ + Convert a kernel config string ('y', 'n', or 'm') to the corresponding + enum member. + + :param str_: The single-character string from kernel config. + :return: The enum member that matches the provided string. + :raises ValueError: If the string is not 'y', 'n', or 'm'. + """ for state in cls: if state.value == str_: return state raise ValueError('No kernel config tristate value matches "{}"'.format(str_)) -class TypedKernelConfig(Mapping): +class TypedKernelConfig(Mapping): # type: ignore """ - Mapping-like typed version of :class:`KernelConfig`. + A mapping-like object representing typed kernel config parameters. Keys are + canonicalized config names (e.g. "CONFIG_FOO"), and values may be strings, ints, + :class:`HexInt`, or :class:`KernelConfigTristate`. + + :param not_set_regex: A regex that matches lines in the form ``# CONFIG_ABC is not set``. - Values are either :class:`str`, :class:`int`, - :class:`KernelConfigTristate`, or :class:`HexInt`. ``hex`` Kconfig type is - mapped to :class:`HexInt` and ``bool`` to :class:`KernelConfigTristate`. + :param mapping: An optional initial mapping of config keys to string values. + Typically set by parsing a kernel .config file or /proc/config.gz content. """ not_set_regex = re.compile(r'# (\S+) is not set') @staticmethod - def get_config_name(name): + def get_config_name(name: str) -> str: + """ + Ensure the config name starts with 'CONFIG_', returning + the canonical form. + + :param name: A raw config key name (e.g. 'ABC'). + :return: The canonical name (e.g. 'CONFIG_ABC'). + """ name = name.upper() if not name.startswith('CONFIG_'): name = 'CONFIG_' + name return name - def __init__(self, mapping=None): + def __init__(self, mapping: Optional[Maptype] = None): + """ + Initialize a typed kernel config from an existing dictionary or None. 
+ + :param mapping: Existing config data (raw strings), keyed by config name. + """ mapping = mapping if mapping is not None else {} - self._config = { + self._config: Dict[str, str] = { # Ensure we use the canonical name of the config keys for internal # representation self.get_config_name(k): v @@ -2910,34 +4558,45 @@ def __init__(self, mapping=None): } @classmethod - def from_str(cls, text): + def from_str(cls, text: str) -> 'TypedKernelConfig': """ - Build a :class:`TypedKernelConfig` out of the string content of a - Kconfig file. + Build a typed config by parsing raw text of a kernel config file. + + :param text: Contents of the kernel config, including lines such as + ``CONFIG_ABC=y`` or ``# CONFIG_DEF is not set``. + :return: A :class:`TypedKernelConfig` reflecting typed config values. """ return cls(cls._parse_text(text)) @staticmethod - def _val_to_str(val): + def _val_to_str(val: Optional[Union[KernelConfigTristate, str]]) -> str: "Convert back values to Kconfig-style string value" # Special case the gracefully handle the output of get() if val is None: - return None + return "" elif isinstance(val, KernelConfigTristate): return val.value - elif isinstance(val, basestring): + elif isinstance(val, str): return '"{}"'.format(val.strip('"')) else: return str(val) def __str__(self): + """ + Convert the typed config back to a kernel config-style string, e.g. + "CONFIG_FOO=y\nCONFIG_BAR=\"value\"\n..." + + :return: A multi-line string representation of the typed config. + """ return '\n'.join( '{}={}'.format(k, self._val_to_str(v)) for k, v in self.items() ) @staticmethod - def _parse_val(k, v): + def _parse_val(k: str, v: Union[str, int, HexInt, + KernelConfigTristate]) -> Optional[Union[KernelConfigTristate, + HexInt, int, str]]: """ Parse a value of types handled by Kconfig: * string @@ -2949,43 +4608,52 @@ def _parse_val(k, v): Since bool cannot be distinguished from tristate, tristate is always used. 
:meth:`KernelConfigTristate.__bool__` will allow using it as a bool though, so it should not impact user code. + + :param k: The config key name (not used heavily). + :param v: The raw string or typed object. + :return: The typed version of the value. """ if not v: return None - # Handle "string" type - if v.startswith('"'): - # Strip enclosing " - return v[1:-1] + if isinstance(v, str): + # Handle "string" type + if v.startswith('"'): + # Strip enclosing " + return v[1:-1] - else: - try: - # Handles "bool" and "tristate" types - return KernelConfigTristate.from_str(v) - except ValueError: - pass + else: + try: + # Handles "bool" and "tristate" types + return KernelConfigTristate.from_str(v) + except ValueError: + pass - try: - # Handles "int" type - return int(v) - except ValueError: - pass + try: + # Handles "int" type + return int(v) + except ValueError: + pass - try: - # Handles "hex" type - return HexInt(v) - except ValueError: - pass + try: + # Handles "hex" type + return HexInt(v) + except ValueError: + pass - # If no type could be parsed - raise ValueError('Could not parse Kconfig key: {}={}'.format( + # If no type could be parsed + raise ValueError('Could not parse Kconfig key: {}={}'.format( k, v ), k, v - ) + ) + return None @classmethod - def _parse_text(cls, text): - config = {} + def _parse_text(cls, text: str) -> Dict[str, Optional[Union[KernelConfigTristate, HexInt, int, str]]]: + """ + parse the kernel config text and create a dictionary of the configs + """ + config: Dict[str, Optional[Union[KernelConfigTristate, HexInt, int, str]]] = {} for line in text.splitlines(): line = line.strip() @@ -2996,19 +4664,19 @@ def _parse_text(cls, text): if line.startswith('#'): match = cls.not_set_regex.search(line) if match: - value = 'n' - name = match.group(1) + value: str = 'n' + name: str = match.group(1) else: continue else: name, value = line.split('=', 1) name = cls.get_config_name(name.strip()) - value = cls._parse_val(name, value.strip()) - 
config[name] = value + parsed_value: Optional[Union[KernelConfigTristate, HexInt, int, str]] = cls._parse_val(name, value.strip()) + config[name] = parsed_value return config - def __getitem__(self, name): + def __getitem__(self, name: str) -> str: name = self.get_config_name(name) try: return self._config[name] @@ -3024,27 +4692,43 @@ def __iter__(self): def __len__(self): return len(self._config) +# FIXME - annotating name as str gives some type errors as Mapping superclass expects object def __contains__(self, name): name = self.get_config_name(name) return name in self._config - def like(self, name): + def like(self, name: str) -> Dict[str, str]: + """ + Return a dictionary of key-value pairs where the keys match the given regular expression pattern. + """ regex = re.compile(name, re.I) return { k: v for k, v in self.items() if regex.search(k) } - def is_enabled(self, name): + def is_enabled(self, name: str) -> bool: + """ + true if the config is enabled in kernel + """ return self.get(name) is KernelConfigTristate.YES - def is_module(self, name): + def is_module(self, name: str) -> bool: + """ + true if the config is of Module type + """ return self.get(name) is KernelConfigTristate.MODULE - def is_not_set(self, name): + def is_not_set(self, name: str) -> bool: + """ + true if the config is not enabled + """ return self.get(name) is KernelConfigTristate.NO - def has(self, name): + def has(self, name: str) -> bool: + """ + true if the config is either enabled or it is a module + """ return self.is_enabled(name) or self.is_module(name) @@ -3055,10 +4739,10 @@ class KernelConfig(object): This class does not provide a Mapping API and only return string values. 
""" @staticmethod - def get_config_name(name): + def get_config_name(name: str) -> str: return TypedKernelConfig.get_config_name(name) - def __init__(self, text): + def __init__(self, text: str): # Expose typed_config as a non-private attribute, so that user code # needing it can get it from any existing producer of KernelConfig. self.typed_config = TypedKernelConfig.from_str(text) @@ -3070,55 +4754,105 @@ def __bool__(self): not_set_regex = TypedKernelConfig.not_set_regex - def iteritems(self): + def iteritems(self) -> Iterator[Tuple[str, str]]: + """ + Iterate over the items in the typed configuration, converting each value to a string. + """ for k, v in self.typed_config.items(): yield (k, self.typed_config._val_to_str(v)) items = iteritems - def get(self, name, strict=False): + def get(self, name: str, strict: bool = False) -> Optional[str]: + """ + Retrieve a value from the typed configuration and convert it to a string. + """ if strict: - val = self.typed_config[name] + val: Optional[str] = self.typed_config[name] else: val = self.typed_config.get(name) return self.typed_config._val_to_str(val) - def like(self, name): + def like(self, name: str) -> Dict[str, str]: + """ + Return a dictionary of key-value pairs where the keys match the given regular expression pattern. 
+ """ return { k: self.typed_config._val_to_str(v) for k, v in self.typed_config.like(name).items() } - def is_enabled(self, name): + def is_enabled(self, name: str) -> bool: + """ + true if the config is enabled in kernel + """ return self.typed_config.is_enabled(name) - def is_module(self, name): + def is_module(self, name: str) -> bool: + """ + true if the config is of Module type + """ return self.typed_config.is_module(name) - def is_not_set(self, name): + def is_not_set(self, name: str) -> bool: + """ + true if the config is not enabled + """ return self.typed_config.is_not_set(name) - def has(self, name): + def has(self, name: str) -> bool: + """ + true if the config is either enabled or it is a module + """ return self.typed_config.has(name) class LocalLinuxTarget(LinuxTarget): + """ + A specialized :class:`Target` subclass representing the local Linux system + (i.e., no remote connection needed). In many respects, this parallels + :class:`LinuxTarget`, but uses :class:`LocalConnection` under the hood. + + :param connection_settings: Dictionary specifying local connection options + (often unused or minimal). + :param platform: A ``Platform`` object if you want to specify architecture, + kernel version, etc. If None, a default is inferred from the host system. + :param working_directory: A writable directory on the local machine for devlibs + temporary operations. If None, a subfolder of /tmp or similar is often used. + :param executables_directory: Directory for installing binaries from devlib, + if needed. + :param connect: Whether to connect (initialize local environment) immediately. + :param modules: Additional devlib modules to load at construction time. + :param load_default_modules: If True, also load modules listed in + :attr:`default_modules`. + :param shell_prompt: Regex matching the local shell prompt (usually not used + since local commands are run directly). + :param conn_cls: Connection class to use, typically :class:`LocalConnection`. 
+ :param is_container: If True, indicates we’re running in a container environment + rather than the full host OS. + :param max_async: Maximum concurrent asynchronous commands allowed. + + """ def __init__(self, - connection_settings=None, - platform=None, - working_directory=None, - executables_directory=None, - connect=True, - modules=None, - load_default_modules=True, - shell_prompt=DEFAULT_SHELL_PROMPT, - conn_cls=LocalConnection, - is_container=False, - max_async=50, - tmp_directory=None, + connection_settings: Optional[UserConnectionSettings] = None, + platform: Optional[Platform] = None, + working_directory: Optional[str] = None, + executables_directory: Optional[str] = None, + connect: bool = True, + modules: Optional[Dict[str, Dict[str, Type[Module]]]] = None, + load_default_modules: bool = True, + shell_prompt: Pattern[str] = DEFAULT_SHELL_PROMPT, + conn_cls: 'InitCheckpointMeta' = LocalConnection, + is_container: bool = False, + max_async: int = 50, + tmp_directory: Optional[str] = None, ): + """ + Initialize a LocalLinuxTarget, representing the local machine as the devlib + target. Optionally connect and load modules immediately. + """ super(LocalLinuxTarget, self).__init__(connection_settings=connection_settings, platform=platform, working_directory=working_directory, @@ -3133,141 +4867,187 @@ def __init__(self, tmp_directory=tmp_directory, ) - def _resolve_paths(self): + def _resolve_paths(self) -> None: + """ + Resolve or finalize local working directories/executables directories. + By default, uses a subfolder of /tmp if none is set. 
+ """ if self.working_directory is None: self.working_directory = '/tmp/devlib-target' -def _get_model_name(section): - name_string = section['model name'] - parts = name_string.split('@')[0].strip().split() +def _get_model_name(section: Dict[str, str]) -> str: + """ + get model name from section of cpu info + """ + name_string: str = section['model name'] + parts: List[str] = name_string.split('@')[0].strip().split() return ' '.join([p for p in parts if '(' not in p and p != 'CPU']) -def _get_part_name(section): - implementer = section.get('CPU implementer', '0x0') - part = section['CPU part'] - variant = section.get('CPU variant', '0x0') +def _get_part_name(section: Dict[str, str]) -> str: + """ + get part name from cpu info + """ + implementer: str = section.get('CPU implementer', '0x0') + part: str = section['CPU part'] + variant: str = section.get('CPU variant', '0x0') name = get_cpu_name(*list(map(integer, [implementer, part, variant]))) if name is None: name = f'{implementer}/{part}/{variant}' return name -def _build_path_tree(path_map, basepath, sep=os.path.sep, dictcls=dict): +Node = Union[str, Dict[str, 'Node']] + + +def _build_path_tree(path_map: Dict[str, str], basepath: str, + sep: str = os.path.sep, dictcls=dict) -> Union[str, Dict[str, 'Node']]: """ Convert a flat mapping of paths to values into a nested structure of - dict-line object (``dict``'s by default), mirroring the directory hierarchy + dict-like object (``dict``'s by default), mirroring the directory hierarchy represented by the paths relative to ``basepath``. 
""" - def process_node(node, path, value): + def process_node(node: 'Node', path: str, value: str): parts = path.split(sep, 1) - if len(parts) == 1: # leaf + if len(parts) == 1 and not isinstance(node, str): # leaf node[parts[0]] = value else: # branch - if parts[0] not in node: - node[parts[0]] = dictcls() - process_node(node[parts[0]], parts[1], value) + if not isinstance(node, str): + if parts[0] not in node: + node[parts[0]] = dictcls() + process_node(node[parts[0]], parts[1], value) - relpath_map = {os.path.relpath(p, basepath): v - for p, v in path_map.items()} + relpath_map: Dict[str, str] = {os.path.relpath(p, basepath): v + for p, v in path_map.items()} if len(relpath_map) == 1 and list(relpath_map.keys())[0] == '.': - result = list(relpath_map.values())[0] + result: Union[str, Dict[str, Any]] = list(relpath_map.values())[0] else: result = dictcls() for path, value in relpath_map.items(): - process_node(result, path, value) + if not isinstance(result, str): + process_node(result, path, value) return result class ChromeOsTarget(LinuxTarget): """ - Class for interacting with ChromeOS targets. + :class:`ChromeOsTarget` is a subclass of :class:`LinuxTarget` with + additional features specific to a device running ChromeOS for example, + if supported, its own android container which can be accessed via the + ``android_container`` attribute. When making calls to or accessing + properties and attributes of the ChromeOS target, by default they will + be applied to Linux target as this is where the majority of device + configuration will be performed and if not available, will fall back to + using the android container if available. This means that all the + available methods from + :class:`LinuxTarget` and :class:`AndroidTarget` are available for + :class:`ChromeOsTarget` if the device supports android otherwise only the + :class:`LinuxTarget` methods will be available. 
+ + :param working_directory: This is the location of the working directory to + be used for the Linux target container. If not specified will default to + ``"/mnt/stateful_partition/devlib-target"``. + + :param android_working_directory: This is the location of the working + directory to be used for the android container. If not specified it will + use the working directory default for :class:`AndroidTarget.`. + + :param android_executables_directory: This is the location of the + executables directory to be used for the android container. If not + specified will default to a ``bin`` subdirectory in the + ``android_working_directory.`` + + :param package_data_directory: This is the location of the data stored + for installed Android packages on the device. """ - os = 'chromeos' + os: str = 'chromeos' # pylint: disable=too-many-locals,too-many-arguments def __init__(self, - connection_settings=None, - platform=None, - working_directory=None, - executables_directory=None, - android_working_directory=None, - android_executables_directory=None, - connect=True, - modules=None, - load_default_modules=True, - shell_prompt=DEFAULT_SHELL_PROMPT, - package_data_directory="/data/data", - is_container=False, - max_async=50, - tmp_directory=None, + connection_settings: Optional[UserConnectionSettings] = None, + platform: Optional[Platform] = None, + working_directory: Optional[str] = None, + executables_directory: Optional[str] = None, + android_working_directory: Optional[str] = None, + android_executables_directory: Optional[str] = None, + connect: bool = True, + modules: Optional[Dict[str, Dict[str, Type[Module]]]] = None, + load_default_modules: bool = True, + shell_prompt: Pattern[str] = DEFAULT_SHELL_PROMPT, + package_data_directory: str = "/data/data", + is_container: bool = False, + max_async: int = 50, + tmp_directory: Optional[str] = None, ): + """ + Initialize a ChromeOsTarget for interacting with a device running Chrome OS + in developer mode (exposing SSH). 
+ """ - self.supports_android = None - self.android_container = None + self.supports_android: Optional[bool] = None + self.android_container: Optional[AndroidTarget] = None # Pull out ssh connection settings - ssh_conn_params = ['host', 'username', 'password', 'keyfile', - 'port', 'timeout', 'sudo_cmd', - 'strict_host_check', 'use_scp', - 'total_transfer_timeout', 'poll_transfers', - 'start_transfer_poll_delay'] - self.ssh_connection_settings = {} - self.ssh_connection_settings.update( - (key, value) - for key, value in connection_settings.items() - if key in ssh_conn_params - ) + ssh_conn_params: List[str] = ['host', 'username', 'password', 'keyfile', + 'port', 'timeout', 'sudo_cmd', + 'strict_host_check', 'use_scp', + 'total_transfer_timeout', 'poll_transfers', + 'start_transfer_poll_delay'] + self.ssh_connection_settings: SshUserConnectionSettings = {} + if connection_settings: + update_dict = cast(SshUserConnectionSettings, + {key: value for key, value in connection_settings.items() if key in ssh_conn_params}) + self.ssh_connection_settings.update(update_dict) super().__init__(connection_settings=self.ssh_connection_settings, - platform=platform, - working_directory=working_directory, - executables_directory=executables_directory, - connect=False, - modules=modules, - load_default_modules=load_default_modules, - shell_prompt=shell_prompt, - conn_cls=SshConnection, - is_container=is_container, - max_async=max_async, - tmp_directory=tmp_directory, - ) + platform=platform, + working_directory=working_directory, + executables_directory=executables_directory, + connect=False, + modules=modules, + load_default_modules=load_default_modules, + shell_prompt=shell_prompt, + conn_cls=SshConnection, + is_container=is_container, + max_async=max_async, + tmp_directory=tmp_directory) # We can't determine if the target supports android until connected to the linux host so # create unconditionally. 
# Pull out adb connection settings adb_conn_params = ['device', 'adb_server', 'adb_port', 'timeout'] - self.android_connection_settings = {} - self.android_connection_settings.update( - (key, value) - for key, value in connection_settings.items() - if key in adb_conn_params - ) - - # If adb device is not explicitly specified use same as ssh host - if not connection_settings.get('device', None): - self.android_connection_settings['device'] = connection_settings.get('host', None) - - self.android_container = AndroidTarget(connection_settings=self.android_connection_settings, - platform=platform, - working_directory=android_working_directory, - executables_directory=android_executables_directory, - connect=False, - load_default_modules=False, - shell_prompt=shell_prompt, - conn_cls=AdbConnection, - package_data_directory=package_data_directory, - is_container=True) - if connect: - self.connect() - - def __getattr__(self, attr): + self.android_connection_settings: AdbUserConnectionSettings = {} + if connection_settings: + update_dict_adb = cast(AdbUserConnectionSettings, + {key: value for key, value in connection_settings.items() if key in adb_conn_params}) + self.android_connection_settings.update(update_dict_adb) + + # If adb device is not explicitly specified use same as ssh host + if not connection_settings.get('device', None): + device = connection_settings.get('host', None) + if device: + self.android_connection_settings['device'] = device + + self.android_container = AndroidTarget(connection_settings=self.android_connection_settings, + platform=platform, + working_directory=android_working_directory, + executables_directory=android_executables_directory, + connect=False, + load_default_modules=False, + shell_prompt=shell_prompt, + conn_cls=AdbConnection, + package_data_directory=package_data_directory, + is_container=True) + if connect: + self.connect() + + def __getattr__(self, attr: str): """ By default use the linux target methods and attributes however, if 
not present, use android implementation if available. @@ -3280,7 +5060,7 @@ def __getattr__(self, attr): raise @asyn.asyncf - async def connect(self, timeout=30, check_boot_completed=True, max_async=None): + async def connect(self, timeout: int = 30, check_boot_completed: bool = True, max_async: Optional[int] = None) -> None: super().connect( timeout=timeout, check_boot_completed=check_boot_completed, @@ -3291,11 +5071,15 @@ async def connect(self, timeout=30, check_boot_completed=True, max_async=None): if self.supports_android is None: self.supports_android = self.directory_exists('/opt/google/containers/android/') - if self.supports_android: + if self.supports_android and self.android_container: self.android_container.connect(timeout) else: self.android_container = None - def _resolve_paths(self): + def _resolve_paths(self) -> None: + """ + Finalize any path logic specific to Chrome OS. Some directories + may be restricted or read-only, depending on dev mode settings. + """ if self.working_directory is None: self.working_directory = '/mnt/stateful_partition/devlib-target' diff --git a/devlib/utils/android.py b/devlib/utils/android.py old mode 100755 new mode 100644 index 001cb93be..aa042e08b --- a/devlib/utils/android.py +++ b/devlib/utils/android.py @@ -1,4 +1,4 @@ -# Copyright 2013-2018 ARM Limited +# Copyright 2013-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,12 +16,10 @@ """ Utility functions for working with Android devices through adb. 
- """ # pylint: disable=E1103 import functools import glob -import logging import os import pexpect import re @@ -38,20 +36,39 @@ from lxml import etree from shlex import quote -from devlib.exception import TargetTransientError, TargetStableError, HostError, TargetTransientCalledProcessError, TargetStableCalledProcessError, AdbRootError -from devlib.utils.misc import check_output, which, ABI_MAP, redirect_streams, get_subprocess -from devlib.connection import ConnectionBase, AdbBackgroundCommand, PopenTransferHandle - - -logger = logging.getLogger('android') - -MAX_ATTEMPTS = 5 -AM_START_ERROR = re.compile(r"Error: Activity.*") -AAPT_BADGING_OUTPUT = re.compile(r"no dump ((file)|(apk)) specified", re.IGNORECASE) +from devlib.exception import (TargetTransientError, TargetStableError, HostError, + TargetTransientCalledProcessError, TargetStableCalledProcessError, AdbRootError) +from devlib.utils.misc import check_output, which, ABI_MAP, redirect_streams, get_subprocess, get_logger +from devlib.connection import (ConnectionBase, AdbBackgroundCommand, + PopenTransferHandle) + +from typing import (Optional, TYPE_CHECKING, cast, Tuple, Union, + List, DefaultDict, Pattern, Dict, Iterator, + Match, Callable) +from collections.abc import Generator +from typing_extensions import Required, TypedDict, Literal +if TYPE_CHECKING: + from devlib.utils.annotation_helpers import SubprocessCommand + from threading import Lock + from lxml.etree import _ElementTree, _Element, XMLParser + from devlib.platform import Platform + from subprocess import Popen, CompletedProcess + from devlib.target import AndroidTarget + from io import TextIOWrapper + from tempfile import _TemporaryFileWrapper + from pexpect import spawn + +PartsType = Tuple[Union[str, Tuple[str, ...]], ...] 
+ +logger = get_logger('android') + +MAX_ATTEMPTS: int = 5 +AM_START_ERROR: Pattern[str] = re.compile(r"Error: Activity.*") +AAPT_BADGING_OUTPUT: Pattern[str] = re.compile(r"no dump ((file)|(apk)) specified", re.IGNORECASE) # See: # http://developer.android.com/guide/topics/manifest/uses-sdk-element.html#ApiLevels -ANDROID_VERSION_MAP = { +ANDROID_VERSION_MAP: Dict[int, str] = { 29: 'Q', 28: 'PIE', 27: 'OREO_MR1', @@ -84,96 +101,220 @@ } # See https://developer.android.com/reference/android/content/Intent.html#setFlags(int) -INTENT_FLAGS = { - 'ACTIVITY_NEW_TASK' : 0x10000000, - 'ACTIVITY_CLEAR_TASK' : 0x00008000 +INTENT_FLAGS: Dict[str, int] = { + 'ACTIVITY_NEW_TASK': 0x10000000, + 'ACTIVITY_CLEAR_TASK': 0x00008000 } + class AndroidProperties(object): + """ + Represents Android system properties as reported by the ``getprop`` command. + Allows easy retrieval of property values. - def __init__(self, text): - self._properties = {} + :param text: Full string output from ``adb shell getprop`` (or similar). + """ + def __init__(self, text: str): + self._properties: Dict[str, str] = {} self.parse(text) - def parse(self, text): + def parse(self, text: str) -> None: + """ + Parse the output text and update the internal property dictionary. + + :param text: String containing the property lines. + """ self._properties = dict(re.findall(r'\[(.*?)\]:\s+\[(.*?)\]', text)) - def iteritems(self): + def iteritems(self) -> Iterator[Tuple[str, str]]: + """ + Return an iterator of (property_key, property_value) pairs. + + :returns: An iterator of tuples like (key, value). + """ return iter(self._properties.items()) def __iter__(self): + """ + Iterate over the property keys. + """ return iter(self._properties) - def __getattr__(self, name): + def __getattr__(self, name: str): + """ + Return a property value by attribute-style lookup. + Defaults to None if the property is missing. 
+ """ return self._properties.get(name) __getitem__ = __getattr__ class AdbDevice(object): + """ + Represents a single device as seen by ``adb devices`` (usually a USB or IP + device). - def __init__(self, name, status): + :param name: The serial number or identifier of the device. + :param status: The device status, e.g. "device", "offline", or "unauthorized". + """ + def __init__(self, name: str, status: str): self.name = name self.status = status - # pylint: disable=undefined-variable - def __cmp__(self, other): + # replace __cmp__ of python 2 with explicit comparison methods + # of python 3 + def __lt__(self, other: Union['AdbDevice', str]) -> bool: + """ + Compare this device's name with another device or string for ordering. + """ if isinstance(other, AdbDevice): - return cmp(self.name, other.name) - else: - return cmp(self.name, other) + return self.name < other.name + return self.name < other - def __str__(self): + def __eq__(self, other: object) -> bool: + """ + Check if this device's name matches another device's name or a string. + """ + if isinstance(other, AdbDevice): + return self.name == other.name + return self.name == other + + def __le__(self, other: Union['AdbDevice', str]) -> bool: + """ + Test if this device's name is <= another device/string. + """ + return self < other or self == other + + def __gt__(self, other: Union['AdbDevice', str]) -> bool: + """ + Test if this device's name is > another device/string. + """ + return not self <= other + + def __ge__(self, other: Union['AdbDevice', str]) -> bool: + """ + Test if this device's name is >= another device/string. + """ + return not self < other + + def __ne__(self, other: object) -> bool: + """ + Invert the __eq__ comparison. + """ + return not self == other + + def __str__(self) -> str: + """ + Return a string representation of this device for debugging. 
+ """ return 'AdbDevice({}, {})'.format(self.name, self.status) __repr__ = __str__ +class BuildToolsInfo(TypedDict, total=False): + """ + Typed dictionary capturing build tools info. + + :param build_tools: The path to the build-tools directory. + :param aapt: Path to the aapt or aapt2 binary. + :param aapt_version: Integer 1 or 2 indicating which aapt is used. + """ + build_tools: Required[Optional[str]] + aapt: Required[Optional[str]] + aapt_version: Required[Optional[int]] + + +class Android_Env_Type(TypedDict, total=False): + """ + Typed dictionary representing environment paths for Android tools. + + :param android_home: ANDROID_HOME path, if set. + :param platform_tools: Path to the 'platform-tools' directory containing adb/fastboot. + :param adb: Path to the 'adb' executable. + :param fastboot: Path to the 'fastboot' executable. + :param build_tools: Path to the 'build-tools' directory if available. + :param aapt: Path to aapt or aapt2, if found. + :param aapt_version: 1 or 2 indicating which aapt variant is used. + """ + android_home: Required[Optional[str]] + platform_tools: Required[str] + adb: Required[str] + fastboot: Required[str] + build_tools: Required[Optional[str]] + aapt: Required[Optional[str]] + aapt_version: Required[Optional[int]] + + +Android_Env_TypeKeys = Union[Literal['android_home'], + Literal['platform_tools'], + Literal['adb'], + Literal['fastboot'], + Literal['build_tools'], + Literal['aapt'], + Literal['aapt_version']] + + class ApkInfo(object): + """ + Extracts and stores metadata about an APK, including package name, version, + supported ABIs, permissions, etc. The parsing relies on the 'aapt' or 'aapt2' + command from Android build-tools. 
-    version_regex = re.compile(r"name='(?P<name>[^']+)' versionCode='(?P<vcode>[^']+)' versionName='(?P<vname>[^']+)'")
-    name_regex = re.compile(r"name='(?P<name>[^']+)'")
-    permission_regex = re.compile(r"name='(?P<name>[^']+)'")
-    activity_regex = re.compile(r'\s*A:\s*android:name\(0x\d+\)=".(?P<name>\w+)"')
+    :param path: Optional path to the APK file on the host. If provided, it is
+        immediately parsed.
+    """
+    version_regex: Pattern[str] = re.compile(r"name='(?P<name>[^']+)' versionCode='(?P<vcode>[^']+)' versionName='(?P<vname>[^']+)'")
+    name_regex: Pattern[str] = re.compile(r"name='(?P<name>[^']+)'")
+    permission_regex: Pattern[str] = re.compile(r"name='(?P<name>[^']+)'")
+    activity_regex: Pattern[str] = re.compile(r'\s*A:\s*android:name\(0x\d+\)=".(?P<name>\w+)"')
 
-    def __init__(self, path=None):
+    def __init__(self, path: Optional[str] = None):
         self.path = path
-        self.package = None
-        self.activity = None
-        self.label = None
-        self.version_name = None
-        self.version_code = None
-        self.native_code = None
-        self.permissions = []
-        self._apk_path = None
-        self._activities = None
-        self._methods = None
-        self._aapt = _ANDROID_ENV.get_env('aapt')
-        self._aapt_version = _ANDROID_ENV.get_env('aapt_version')
+        self.package: Optional[str] = None
+        self.activity: Optional[str] = None
+        self.label: Optional[str] = None
+        self.version_name: Optional[str] = None
+        self.version_code: Optional[str] = None
+        self.native_code: Optional[List[str]] = None
+        self.permissions: List[str] = []
+        self._apk_path: Optional[str] = None
+        self._activities: Optional[List[str]] = None
+        self._methods: Optional[List[Tuple[str, str]]] = None
+        self._aapt: str = cast(str, _ANDROID_ENV.get_env('aapt'))
+        self._aapt_version: int = cast(int, _ANDROID_ENV.get_env('aapt_version'))
         if path:
            self.parse(path)

    # pylint: disable=too-many-branches
-    def parse(self, apk_path):
-        output = self._run([self._aapt, 'dump', 'badging', apk_path])
+    def parse(self, apk_path: str) -> None:
+        """
+        Parse the given APK file with the aapt or aapt2 utility, retrieving
+        metadata such as
package name, version, and permissions. + + :param apk_path: The path to the APK file on the host system. + :raises HostError: If aapt fails to run or returns an error message. + """ + output: str = self._run([self._aapt, 'dump', 'badging', apk_path]) for line in output.split('\n'): if line.startswith('application-label:'): self.label = line.split(':')[1].strip().replace('\'', '') elif line.startswith('package:'): - match = self.version_regex.search(line) + match: Optional[Match[str]] = self.version_regex.search(line) if match: self.package = match.group('name') self.version_code = match.group('vcode') self.version_name = match.group('vname') elif line.startswith('launchable-activity:'): match = self.name_regex.search(line) - self.activity = match.group('name') + self.activity = match.group('name') if match else None elif line.startswith('native-code'): - apk_abis = [entry.strip() for entry in line.split(':')[1].split("'") if entry.strip()] - mapped_abis = [] + apk_abis: List[str] = [entry.strip() for entry in line.split(':')[1].split("'") if entry.strip()] + mapped_abis: List[str] = [] for apk_abi in apk_abis: - found = False + found: bool = False for abi, architectures in ABI_MAP.items(): if apk_abi in architectures: mapped_abis.append(abi) @@ -194,37 +335,49 @@ def parse(self, apk_path): self._methods = None @property - def activities(self): + def activities(self) -> List[str]: + """ + Return a list of activity names declared in this APK. + + :returns: A list of activity names found in AndroidManifest.xml. 
+ """ if self._activities is None: - cmd = [self._aapt, 'dump', 'xmltree', self._apk_path] + cmd: List[str] = [self._aapt, 'dump', 'xmltree', self._apk_path if self._apk_path else ''] if self._aapt_version == 2: cmd += ['--file'] cmd += ['AndroidManifest.xml'] - matched_activities = self.activity_regex.finditer(self._run(cmd)) + matched_activities: Iterator[Match[str]] = self.activity_regex.finditer(self._run(cmd)) self._activities = [m.group('name') for m in matched_activities] return self._activities @property - def methods(self): + def methods(self) -> Optional[List[Tuple[str, str]]]: + """ + Return a list of (method_name, class_name) pairs, if any can be extracted + by dexdump. If no classes.dex is found or an error occurs, returns an empty list. + + :returns: A list of (method_name, class_name) tuples, or None if not parsed yet. + """ if self._methods is None: # Only try to extract once self._methods = [] with tempfile.TemporaryDirectory() as tmp_dir: - with zipfile.ZipFile(self._apk_path, 'r') as z: - try: - extracted = z.extract('classes.dex', tmp_dir) - except KeyError: - return [] - dexdump = os.path.join(os.path.dirname(self._aapt), 'dexdump') - command = [dexdump, '-l', 'xml', extracted] - dump = self._run(command) + if self._apk_path: + with zipfile.ZipFile(self._apk_path, 'r') as z: + try: + extracted: str = z.extract('classes.dex', tmp_dir) + except KeyError: + return [] + dexdump: str = os.path.join(os.path.dirname(self._aapt), 'dexdump') + command: List[str] = [dexdump, '-l', 'xml', extracted] + dump: str = self._run(command) # Dexdump from build tools v30.0.X does not seem to produce # valid xml from certain APKs so ignore errors and attempt to recover. 
- parser = etree.XMLParser(encoding='utf-8', recover=True) - xml_tree = etree.parse(StringIO(dump), parser) + parser: XMLParser = etree.XMLParser(encoding='utf-8', recover=True) + xml_tree: _ElementTree = etree.parse(StringIO(dump), parser) - package = [] + package: List[_Element] = [] for i in xml_tree.iter('package'): if i.attrib['name'] == self.package: package.append(i) @@ -235,11 +388,18 @@ def methods(self): for meth in klass.iter('method')]) return self._methods - def _run(self, command): + def _run(self, command: List[str]) -> str: + """ + Execute a local shell command (e.g., aapt) and return its output as a string. + + :param command: List of command arguments to run. + :returns: Combined stdout+stderr as a decoded string. + :raises HostError: If the command fails or returns a nonzero exit code. + """ logger.debug(' '.join(command)) try: - output = subprocess.check_output(command, stderr=subprocess.STDOUT) - output = output.decode(sys.stdout.encoding or 'utf-8', 'replace') + output_tmp: bytes = subprocess.check_output(command, stderr=subprocess.STDOUT) + output: str = output_tmp.decode(sys.stdout.encoding or 'utf-8', 'replace') except subprocess.CalledProcessError as e: raise HostError('Error while running "{}":\n{}' .format(command, e.output)) @@ -247,46 +407,96 @@ def _run(self, command): class AdbConnection(ConnectionBase): - + """ + A connection to an android device via ``adb`` (Android Debug Bridge). + ``adb`` is part of the Android SDK (though stand-alone versions are also + available). + + :param device: The name of the adb device. This is usually a unique hex + string for USB-connected devices, or an ip address/port + combination. To see connected devices, you can run ``adb + devices`` on the host. + :param timeout: Connection timeout in seconds. If a connection to the device + is not established within this period, :class:`HostError` + is raised. + :param platform: An optional Platform object describing hardware aspects. 
+ :param adb_server: Allows specifying the address of the adb server to use. + :param adb_port: If specified, connect to a custom adb server port. + :param adb_as_root: Specify whether the adb server should be restarted in root mode. + :param connection_attempts: Specify how many connection attempts, 10 seconds + apart, should be attempted to connect to the device. + Defaults to 5. + :param poll_transfers: Specify whether file transfers should be polled. Polling + monitors the progress of file transfers and periodically + checks whether they have stalled, attempting to cancel + the transfers prematurely if so. + :param start_transfer_poll_delay: If transfers are polled, specify the length of + time after a transfer has started before polling + should start. + :param total_transfer_timeout: If transfers are polled, specify the total amount of time + to elapse before the transfer is cancelled, regardless + of its activity. + :param transfer_poll_period: If transfers are polled, specify the period at which + the transfers are sampled for activity. Too small values + may cause the destination size to appear the same over + one or more sample periods, causing improper transfer + cancellation. + + :raises AdbRootError: If root mode is requested but multiple connections are active or device does not allow it. + :raises HostError: If the device fails to connect or is invalid. 
+ """ # maintains the count of parallel active connections to a device, so that # adb disconnect is not invoked untill all connections are closed - active_connections = (threading.Lock(), defaultdict(int)) + active_connections: Tuple['Lock', DefaultDict[str, int]] = (threading.Lock(), defaultdict(int)) # Track connected as root status per device - _connected_as_root = defaultdict(lambda: None) - default_timeout = 10 - ls_command = 'ls' - su_cmd = 'su -c {}' + _connected_as_root: DefaultDict[str, Optional[bool]] = defaultdict(lambda: None) + default_timeout: int = 10 + ls_command: str = 'ls' + su_cmd: str = 'su -c {}' @property - def name(self): + def name(self) -> str: + """ + :returns: The device serial number or IP:port used by this connection. + """ return self.device @property - def connected_as_root(self): + def connected_as_root(self) -> Optional[bool]: + """ + Check if the current connection is effectively root on the device. + + :returns: True if root, False if not, or None if undetermined. + """ if self._connected_as_root[self.device] is None: result = self.execute('id') self._connected_as_root[self.device] = 'uid=0(' in result return self._connected_as_root[self.device] @connected_as_root.setter - def connected_as_root(self, state): + def connected_as_root(self, state: Optional[bool]) -> None: + """ + Manually set the known state of root usage on this device connection. + + :param state: True if connected as root, False if not, None to reset. 
+ """ self._connected_as_root[self.device] = state # pylint: disable=unused-argument def __init__( self, - device=None, - timeout=None, - platform=None, - adb_server=None, - adb_port=None, - adb_as_root=False, - connection_attempts=MAX_ATTEMPTS, - - poll_transfers=False, - start_transfer_poll_delay=30, - total_transfer_timeout=3600, - transfer_poll_period=30, + device: Optional[str] = None, + timeout: Optional[int] = None, + platform: Optional['Platform'] = None, + adb_server: Optional[str] = None, + adb_port: Optional[int] = None, + adb_as_root: bool = False, + connection_attempts: int = MAX_ATTEMPTS, + + poll_transfers: bool = False, + start_transfer_poll_delay: int = 30, + total_transfer_timeout: int = 3600, + transfer_poll_period: int = 30, ): super().__init__( poll_transfers=poll_transfers, @@ -323,19 +533,40 @@ def __init__( self._setup_ls() self._setup_su() - def push(self, sources, dest, timeout=None): + def push(self, sources: List[str], dest: str, + timeout: Optional[int] = None) -> None: + """ + Upload (push) one or more files/directories from the host to the device. + + :param sources: Paths on the host system to be pushed. + :param dest: Target path on the device. If multiple sources, dest should be a dir. + :param timeout: Max time in seconds for each file push. If exceeded, an error is raised. + """ return self._push_pull('push', sources, dest, timeout) - def pull(self, sources, dest, timeout=None): + def pull(self, sources: List[str], dest: str, + timeout: Optional[int] = None) -> None: + """ + Download (pull) one or more files/directories from the device to the host. + + :param sources: Paths on the device to be pulled. + :param dest: Destination path on the host. + :param timeout: Max time in seconds for each file. If exceeded, an error is raised. 
+ """ return self._push_pull('pull', sources, dest, timeout) - def _push_pull(self, action, sources, dest, timeout): - sources = list(sources) - paths = sources + [dest] + def _push_pull(self, action: Union[Literal['push'], Literal['pull']], + sources: List[str], dest: str, timeout: Optional[int]) -> None: + """ + Internal helper that runs 'adb push' or 'adb pull' with optional timeouts + and transfer polling. + """ + sourcesList: List[str] = list(sources) + pathsList: List[str] = sourcesList + [dest] # Quote twice to avoid expansion by host shell, then ADB globbing - do_quote = lambda x: quote(glob.escape(x)) - paths = ' '.join(map(do_quote, paths)) + do_quote: Callable[[str], str] = lambda x: quote(glob.escape(x)) + paths: str = ' '.join(map(do_quote, pathsList)) command = "{} {}".format(action, paths) if timeout: @@ -359,8 +590,23 @@ def _push_pull(self, action, sources, dest, timeout): popen.communicate() # pylint: disable=unused-argument - def execute(self, command, timeout=None, check_exit_code=False, - as_root=False, strip_colors=True, will_succeed=False): + def execute(self, command: 'SubprocessCommand', timeout: Optional[int] = None, + check_exit_code: bool = False, as_root: Optional[bool] = False, + strip_colors: bool = True, will_succeed: bool = False) -> str: + """ + Execute a command on the device via ``adb shell``. + + :param command: The command line to run (string or SubprocessCommand). + :param timeout: Time in seconds before forcibly terminating the command. None for no limit. + :param check_exit_code: If True, raise an error if the command's exit code != 0. + :param as_root: If True, attempt to run it as root if available. + :param strip_colors: If True, strip any ANSI colors (unused in this method). + :param will_succeed: If True, treat an error as transient rather than stable. + :returns: The command's output (combined stdout+stderr). + :raises TargetTransientCalledProcessError: If the command fails but is flagged as transient. 
+ :raises TargetStableCalledProcessError: If the command fails in a stable (non-transient) way. + :raises TargetStableError: If there's a stable device/command error. + """ if as_root and self.connected_as_root: as_root = False try: @@ -380,13 +626,35 @@ def execute(self, command, timeout=None, check_exit_code=False, else: raise - def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False): + def background(self, command: 'SubprocessCommand', stdout: int = subprocess.PIPE, + stderr: int = subprocess.PIPE, as_root: Optional[bool] = False) -> AdbBackgroundCommand: + """ + Launch a background command via adb shell and return a handle to manage it. + + :param command: The command to run on the device. + :param stderr: File descriptor or special value for stderr. + :param as_root: If True, attempt to run the command as root. + :returns: A handle to the background command. + + .. note:: This **will block the connection** until the command completes. + """ if as_root and self.connected_as_root: as_root = False - bg_cmd = self._background(command, stdout, stderr, as_root) + bg_cmd: AdbBackgroundCommand = self._background(command, stdout, stderr, as_root) return bg_cmd - def _background(self, command, stdout, stderr, as_root): + def _background(self, command: 'SubprocessCommand', stdout: int, + stderr: int, as_root: Optional[bool]) -> AdbBackgroundCommand: + """ + Helper method to run a background shell command via adb. + + :param command: Shell command to run. + :param stdout: Location for stdout writes. + :param stderr: Location for stderr writes. + :param as_root: If True, run as root if possible. + :returns: An AdbBackgroundCommand object. + :raises Exception: If PID detection fails or no valid device is set. 
+ """ def make_init_kwargs(command): adb_popen, pid = adb_background_shell(self, command, stdout, stderr, as_root) return dict( @@ -402,7 +670,15 @@ def make_init_kwargs(command): ) return bg_cmd - def _close(self): + def _close(self) -> None: + """ + Close the connection to the device. The :class:`Connection` object should not + be used after this method is called. There is no way to reopen a previously + closed connection, a new connection object should be created instead. + """ + if not hasattr(AdbConnection, "active_connections") or AdbConnection.active_connections is None: + return # Prevents AttributeError when closing a non-existent connection + lock, nr_active = AdbConnection.active_connections with lock: nr_active[self.device] -= 1 @@ -415,13 +691,24 @@ def _close(self): self.adb_root(enable=self._restore_to_adb_root) adb_disconnect(self.device, self.adb_server, self.adb_port) - def cancel_running_command(self): + def cancel_running_command(self) -> None: + """ + Cancel a running command (previously started with :func:`background`) and free up the connection. + It is valid to call this if the command has already terminated (or if no + command was issued), in which case this is a no-op. + """ # adbd multiplexes commands so that they don't interfer with each # other, so there is no need to explicitly cancel a running command # before the next one can be issued. pass def adb_root(self, enable=True): + """ + Enable or disable root mode for this device connection. + + :param enable: True to enable root, False to unroot. + :raises AdbRootError: If multiple connections are active or device disallows root. 
+ """ self._adb_root(enable=enable) def _adb_root(self, enable): @@ -451,33 +738,48 @@ def is_rooted(out): AdbConnection._connected_as_root[self.device] = enable return was_rooted - def wait_for_device(self, timeout=30): + def wait_for_device(self, timeout: Optional[int] = 30) -> None: + """ + Block until the device is available for commands, up to a specified timeout. + + :param timeout: Time in seconds before giving up. + """ adb_command(self.device, 'wait-for-device', timeout, self.adb_server, self.adb_port) - def reboot_bootloader(self, timeout=30): + def reboot_bootloader(self, timeout: int = 30) -> None: + """ + Reboot the device into its bootloader (fastboot) mode. + + :param timeout: Seconds to wait for the reboot command to be accepted. + """ adb_command(self.device, 'reboot-bootloader', timeout, self.adb_server, self.adb_port) # Again, we need to handle boards where the default output format from ls is # single column *and* boards where the default output is multi-column. # We need to do this purely because the '-1' option causes errors on older # versions of the ls tool in Android pre-v7. - def _setup_ls(self): + def _setup_ls(self) -> None: + """ + Detect whether 'ls -1' is supported, falling back to plain 'ls' on older devices. + """ command = "shell '(ls -1); echo \"\n$?\"'" try: output = adb_command(self.device, command, timeout=self.timeout, adb_server=self.adb_server, adb_port=self.adb_port) except subprocess.CalledProcessError as e: raise HostError( - 'Failed to set up ls command on Android device. Output:\n' - + e.output) - lines = output.splitlines() - retval = lines[-1].strip() + 'Failed to set up ls command on Android device. 
Output:\n' + e.output) + lines: List[str] = output.splitlines() + retval: str = lines[-1].strip() if int(retval) == 0: self.ls_command = 'ls -1' else: self.ls_command = 'ls' logger.debug("ls command is set to {}".format(self.ls_command)) - def _setup_su(self): + def _setup_su(self) -> None: + """ + Attempt to confirm if 'su -c' is required or a simpler 'su' approach works. + """ # Already root, nothing to do if self.connected_as_root: return @@ -492,26 +794,49 @@ def _setup_su(self): logger.debug("su command is set to {}".format(quote(self.su_cmd))) -def fastboot_command(command, timeout=None, device=None): - target = '-s {}'.format(quote(device)) if device else '' - bin_ = _ANDROID_ENV.get_env('fastboot') - full_command = f'{bin} {target} {command}' +def fastboot_command(command: str, timeout: Optional[int] = None, + device: Optional[str] = None) -> str: + """ + Execute a fastboot command, optionally targeted at a specific device. + + :param command: The fastboot subcommand (e.g. 'devices', 'flash'). + :param timeout: Time in seconds before the command fails. + :param device: Fastboot device name. If None, assumes a single device or environment default. + :returns: Combined stdout+stderr output from the fastboot command. + :raises HostError: If the command fails or returns an error. + """ + target: str = '-s {}'.format(quote(device)) if device else '' + bin_: str = cast(str, _ANDROID_ENV.get_env('fastboot')) + full_command: str = f'{bin_} {target} {command}' logger.debug(full_command) output, _ = check_output(full_command, timeout, shell=True) return output -def fastboot_flash_partition(partition, path_to_image): - command = 'flash {} {}'.format(quote(partition), quote(path_to_image)) +def fastboot_flash_partition(partition: str, path_to_image: str) -> None: + """ + Execute 'fastboot flash ' to flash a file + onto a specific partition of the device. + + :param partition: The device partition to flash (e.g. "boot", "system"). 
+ :param path_to_image: Full path to the image file on the host. + :raises HostError: If fastboot fails or device is not in fastboot mode. + """ + command: str = 'flash {} {}'.format(quote(partition), quote(path_to_image)) fastboot_command(command) -def adb_get_device(timeout=None, adb_server=None, adb_port=None): +def adb_get_device(timeout: Optional[int] = None, adb_server: Optional[str] = None, + adb_port: Optional[int] = None) -> str: """ - Returns the serial number of a connected android device. - - If there are more than one device connected to the machine, or it could not - find any device connected, :class:`devlib.exceptions.HostError` is raised. + Attempt to auto-detect a single connected device. If multiple or none are found, + raise an error. + + :param timeout: Maximum time to wait for device detection, or None for no limit. + :param adb_server: Optional custom server host. + :param adb_port: Optional custom server port. + :returns: The device serial number or IP:port if exactly one device is found. + :raises HostError: If zero or more than one devices are connected. """ # TODO this is a hacky way to issue a adb command to all listed devices @@ -523,67 +848,98 @@ def adb_get_device(timeout=None, adb_server=None, adb_port=None): # a list of the devices sperated by new line # The last line is a blank new line. 
in otherwords, if there is a device found # then the output length is 2 + (1 for each device) - start = time.time() + start: float = time.time() while True: - output = adb_command(None, "devices", adb_server=adb_server, adb_port=adb_port).splitlines() # pylint: disable=E1103 - output_length = len(output) + output: List[str] = adb_command(None, "devices", adb_server=adb_server, adb_port=adb_port).splitlines() # pylint: disable=E1103 + output_length: int = len(output) if output_length == 3: # output[1] is the 2nd line in the output which has the device name # Splitting the line by '\t' gives a list of two indexes, which has # device serial in 0 number and device type in 1. return output[1].split('\t')[0] elif output_length > 3: - message = '{} Android devices found; either explicitly specify ' +\ - 'the device you want, or make sure only one is connected.' + message: str = '{} Android devices found; either explicitly specify ' +\ + 'the device you want, or make sure only one is connected.' raise HostError(message.format(output_length - 2)) else: - if timeout < time.time() - start: + if timeout is not None and timeout < time.time() - start: raise HostError('No device is connected and available') time.sleep(1) -def adb_connect(device, timeout=None, attempts=MAX_ATTEMPTS, adb_server=None, adb_port=None): - tries = 0 - output = None +def adb_connect(device: Optional[str], timeout: Optional[int] = None, + attempts: int = MAX_ATTEMPTS, adb_server: Optional[str] = None, + adb_port: Optional[int] = None) -> None: + """ + Connect to an ADB-over-IP device or ensure a USB device is listed. Re-tries + until success or attempts are exhausted. + + :param device: The device name, if "." in it, assumes IP-based device. + :param timeout: Time in seconds for each attempt before giving up. + :param attempts: Number of times to retry connecting 10 seconds apart. + :param adb_server: Optional ADB server host. + :param adb_port: Optional ADB server port. 
+ :raises HostError: If connection fails after all attempts. + """ + tries: int = 0 + output: Optional[str] = None while tries <= attempts: tries += 1 if device: - if "." in device: # Connect is required only for ADB-over-IP + if "." in device: # Connect is required only for ADB-over-IP # ADB does not automatically remove a network device from it's # devices list when the connection is broken by the remote, so the # adb connection may have gone "stale", resulting in adb blocking # indefinitely when making calls to the device. To avoid this, # always disconnect first. adb_disconnect(device, adb_server, adb_port) - adb_cmd = get_adb_command(None, 'connect', adb_server, adb_port) - command = '{} {}'.format(adb_cmd, quote(device)) + adb_cmd: str = get_adb_command(None, 'connect', adb_server, adb_port) + command: str = '{} {}'.format(adb_cmd, quote(device)) logger.debug(command) output, _ = check_output(command, shell=True, timeout=timeout) if _ping(device, adb_server, adb_port): break time.sleep(10) else: # did not connect to the device - message = f'Could not connect to {device or "a device"} at {adb_server}:{adb_port}' + message: str = f'Could not connect to {device or "a device"} at {adb_server}:{adb_port}' if output: message += f'; got: {output}' raise HostError(message) -def adb_disconnect(device, adb_server=None, adb_port=None): +def adb_disconnect(device: Optional[str], adb_server: Optional[str] = None, + adb_port: Optional[int] = None) -> None: + """ + Issue an 'adb disconnect' for the specified device, if relevant. + + :param device: Device serial or IP:port. If None or no IP in the name, no action is taken. + :param adb_server: Custom ADB server host if used. + :param adb_port: Custom ADB server port if used. 
+ """ if not device: return if ":" in device and device in adb_list_devices(adb_server, adb_port): - adb_cmd = get_adb_command(None, 'disconnect', adb_server, adb_port) - command = "{} {}".format(adb_cmd, device) + adb_cmd: str = get_adb_command(None, 'disconnect', adb_server, adb_port) + command: str = "{} {}".format(adb_cmd, device) logger.debug(command) - retval = subprocess.call(command, stdout=subprocess.DEVNULL, shell=True) + retval: int = subprocess.call(command, stdout=subprocess.DEVNULL, shell=True) if retval: raise TargetTransientError('"{}" returned {}'.format(command, retval)) -def _ping(device, adb_server=None, adb_port=None): - adb_cmd = get_adb_command(device, 'shell', adb_server, adb_port) - command = "{} {}".format(adb_cmd, quote('ls /data/local/tmp > /dev/null')) +def _ping(device: Optional[str], adb_server: Optional[str] = None, + adb_port: Optional[int] = None) -> bool: + """ + Ping the specified device by issuing a trivial command (ls /data/local/tmp). + If it fails, the device is presumably unreachable or offline. + + :param device: The device name or IP:port. + :param adb_server: ADB server host, if any. + :param adb_port: ADB server port, if any. + :returns: True if the device responded, otherwise False. 
+ """ + adb_cmd: str = get_adb_command(device, 'shell', adb_server, adb_port) + command: str = "{} {}".format(adb_cmd, quote('ls /data/local/tmp > /dev/null')) logger.debug(command) try: subprocess.check_output(command, stderr=subprocess.STDOUT, shell=True) @@ -595,23 +951,39 @@ def _ping(device, adb_server=None, adb_port=None): # pylint: disable=too-many-locals -def adb_shell(device, command, timeout=None, check_exit_code=False, - as_root=False, adb_server=None, adb_port=None, su_cmd='su -c {}'): # NOQA - +def adb_shell(device: str, command: 'SubprocessCommand', timeout: Optional[int] = None, + check_exit_code: bool = False, as_root: Optional[bool] = False, adb_server: Optional[str] = None, + adb_port:Optional[int]=None, su_cmd:str='su -c {}') -> str: # NOQA + """ + Run a command in 'adb shell' mode, capturing both stdout/stderr. Uses a technique + to capture the actual command's exit code so that we can detect non-zero exit + reliably on older ADB combos. + + :param device: The device serial or IP:port. + :param command: The command line to run inside 'adb shell'. + :param timeout: Time in seconds to wait for the command, or None for no limit. + :param check_exit_code: If True, raise an error if the command exit code is nonzero. + :param as_root: If True, prepend an su command to run as root if supported. + :param adb_server: Optional custom adb server IP/name. + :param adb_port: Optional custom adb server port. + :param su_cmd: Command template to wrap as root, e.g. 'su -c {}'. + :returns: The combined stdout from the command (minus the exit code). + :raises TargetStableError: If there's an error with the command or exit code extraction fails. + """ # On older combinations of ADB/Android versions, the adb host command always # exits with 0 if it was able to run the command on the target, even if the # command failed (https://code.google.com/p/android/issues/detail?id=3254). 
# Homogenise this behaviour by running the command then echoing the exit # code of the executed command itself. - command = r'({}); echo "\n$?"'.format(command) + command = r'({}); echo "\n$?"'.format(cast(str, command)) command = su_cmd.format(quote(command)) if as_root else command command = ('shell', command) parts, env = _get_adb_parts(command, device, adb_server, adb_port, quote_adb=False) env = {**os.environ, **env} - logger.debug(' '.join(quote(part) for part in parts)) + logger.debug(' '.join(quote(cast(str, part)) for part in parts)) try: - raw_output, error = check_output(parts, timeout, shell=False, env=env) + raw_output, error = check_output(cast('SubprocessCommand', parts), timeout, shell=False, env=env) except subprocess.CalledProcessError as e: raise TargetStableError(str(e)) @@ -629,10 +1001,10 @@ def adb_shell(device, command, timeout=None, check_exit_code=False, exit_code = exit_code.strip() re_search = AM_START_ERROR.findall(output) if exit_code.isdigit(): - exit_code = int(exit_code) - if exit_code: + exit_code_i = int(exit_code) + if exit_code_i: raise subprocess.CalledProcessError( - exit_code, + exit_code_i, command, output, error, @@ -654,11 +1026,23 @@ def adb_shell(device, command, timeout=None, check_exit_code=False, return '\n'.join(x for x in (output, error) if x) -def adb_background_shell(conn, command, +def adb_background_shell(conn: AdbConnection, command: 'SubprocessCommand', stdout=subprocess.PIPE, stderr=subprocess.PIPE, - as_root=False): - """Runs the specified command in a subprocess, returning the the Popen object.""" + as_root: Optional[bool] = False) -> Tuple['Popen', int]: + """ + Run a command in the background on the device via ADB shell, returning a Popen + object and an integer PID. This approach uses SIGSTOP to freeze the shell + while the PID is identified. + + :param conn: The AdbConnection managing the device. + :param command: A shell command to run in the background. 
+ :param stdout: File descriptor for stdout, default is pipe. + :param stderr: File descriptor for stderr, default is pipe. + :param as_root: If True, attempt to run under su if root is available. + :returns: A tuple of (popen_obj, pid). + :raises TargetTransientError: If the PID cannot be identified after retries. + """ device = conn.device adb_server = conn.adb_server adb_port = conn.adb_port @@ -667,12 +1051,12 @@ def adb_background_shell(conn, command, stdout, stderr, command = redirect_streams(stdout, stderr, command) if as_root: - command = f'{busybox} printf "%s" {quote(command)} | su' + command = f'{busybox} printf "%s" {quote(cast(str, command))} | su' - def with_uuid(cmd): + def with_uuid(cmd: str) -> Tuple[str, str]: # Attach a unique UUID to the command line so it can be looked for # without any ambiguity with ps - uuid_ = uuid.uuid4().hex + uuid_: str = uuid.uuid4().hex # Unset the var, since not all connection types set it. This will avoid # anyone depending on that value. cmd = f'DEVLIB_CMD_UUID={uuid_}; unset DEVLIB_CMD_UUID; {cmd}' @@ -682,16 +1066,16 @@ def with_uuid(cmd): return (uuid_, cmd) # Freeze the command with SIGSTOP to avoid racing with PID detection. 
- command = f"{busybox} kill -STOP $$ && exec {busybox} sh -c {quote(command)}" + command = f"{busybox} kill -STOP $$ && exec {busybox} sh -c {quote(cast(str, command))}" command_uuid, command = with_uuid(command) - adb_cmd = get_adb_command(device, 'shell', adb_server, adb_port) - full_command = f'{adb_cmd} {quote(command)}' + adb_cmd: str = get_adb_command(device, 'shell', adb_server, adb_port) + full_command: str = f'{adb_cmd} {quote(cast(str, command))}' logger.debug(full_command) - p = subprocess.Popen(full_command, stdout=stdout, stderr=stderr, stdin=subprocess.PIPE, shell=True) + p: 'Popen' = subprocess.Popen(full_command, stdout=stdout, stderr=stderr, stdin=subprocess.PIPE, shell=True) # Out of band PID lookup, to avoid conflicting needs with stdout redirection - grep_cmd = f'{busybox} grep {quote(command_uuid)}' + grep_cmd: str = f'{busybox} grep {quote(command_uuid)}' # Find the PID and release the blocked background command with SIGCONT. # We get multiple PIDs: # * One from the grep command itself, but we remove it with another grep command. @@ -700,15 +1084,15 @@ def with_uuid(cmd): # For each of the parent layer, we issue SIGCONT as it is harmless and # avoids having to rely on PID ordering (which could be misleading if PIDs # got recycled). - find_pid = f'''pids=$({busybox} ps -A -o pid,args | {grep_cmd} | {busybox} grep -v {quote(grep_cmd)} | {busybox} awk '{{print $1}}') && {busybox} printf "%s" "$pids" && {busybox} kill -CONT $pids''' + find_pid: str = f'''pids=$({busybox} ps -A -o pid,args | {grep_cmd} | {busybox} grep -v {quote(grep_cmd)} | {busybox} awk '{{print $1}}') && {busybox} printf "%s" "$pids" && {busybox} kill -CONT $pids''' - excep = None + excep: Optional[Exception] = None for _ in range(5): try: - pids = conn.execute(find_pid, as_root=as_root) + pids: str = conn.execute(find_pid, as_root=as_root) # We choose the highest PID as the "control" PID. 
def adb_kill_server(timeout: Optional[int] = 30, adb_server: Optional[str] = None,
                    adb_port: Optional[int] = None) -> None:
    """
    Run ``adb kill-server`` to forcibly shut down the local ADB server.

    :param timeout: Seconds to wait for the command to complete.
    :param adb_server: Optional custom ADB server host.
    :param adb_port: Optional custom ADB server port.
    """
    adb_command(None, 'kill-server', timeout, adb_server, adb_port)


def adb_list_devices(adb_server: Optional[str] = None,
                     adb_port: Optional[int] = None) -> List[AdbDevice]:
    """
    List the devices known to ADB by parsing the output of ``adb devices``.

    :param adb_server: Optional custom ADB server host.
    :param adb_port: Optional custom ADB server port.
    :returns: An :class:`AdbDevice` for every output line that parses as a
        ``(serial, status)`` pair.
    """
    output = adb_command(None, 'devices', adb_server=adb_server, adb_port=adb_port)
    devices: List[AdbDevice] = []
    for line in output.splitlines():
        parts = [p.strip() for p in line.split()]
        # 'adb devices' lines are '<serial>\t<status>'; headers and blank
        # lines do not split into exactly two fields and are skipped.
        if len(parts) == 2:
            devices.append(AdbDevice(*parts))
    return devices


def _get_adb_parts(command: Union[Tuple[str], Tuple[str, str]],
                   device: Optional[str] = None,
                   adb_server: Optional[str] = None,
                   adb_port: Optional[int] = None,
                   quote_adb: bool = True) -> Tuple['PartsType', Dict[str, str]]:
    """
    Build the argv-style parts of an adb invocation, plus the environment
    overrides it should run with.

    :param command: Tuple of adb sub-command parts, e.g. ``('shell', 'ls')``.
    :param device: Device serial/address, or None to omit the ``-s`` flag.
    :param adb_server: Custom ADB server host, or None to omit ``-H``.
    :param adb_port: Custom ADB server port, or None to omit ``-P``.
    :param quote_adb: Whether to shell-quote the server/port/device arguments.
    :returns: ``(parts, env)``; ``env`` forces ``LC_ALL=C`` so adb output
        parsing is locale-independent.
    """
    _quote = quote if quote_adb else lambda x: x
    parts: 'PartsType' = (
        cast(str, _ANDROID_ENV.get_env('adb')),
        *(('-H', _quote(adb_server)) if adb_server is not None else ()),
        *(('-P', _quote(str(adb_port))) if adb_port is not None else ()),
        *(('-s', _quote(device)) if device is not None else ()),
        *command,
    )
    env = {'LC_ALL': 'C'}
    return (parts, env)


def get_adb_command(device: Optional[str], command: str,
                    adb_server: Optional[str] = None,
                    adb_port: Optional[int] = None) -> str:
    """
    Build a single shell string invoking adb, prefixed with the ``LC_ALL``
    environment assignment.

    :param device: Device serial/address, or None to omit the ``-s`` flag.
    :param command: The adb sub-command, e.g. ``'shell'`` or ``'push'``.
    :param adb_server: Optional custom ADB server host.
    :param adb_port: Optional custom ADB server port.
    :returns: A command string suitable for running in a host shell.
    """
    parts, env = _get_adb_parts((command,), device, adb_server, adb_port, quote_adb=True)
    assignments = [quote(f'{name}={val}') for name, val in sorted(env.items())]
    return ' '.join([*assignments, *parts])


def adb_command(device: Optional[str], command: str, timeout: Optional[int] = None,
                adb_server: Optional[str] = None, adb_port: Optional[int] = None) -> str:
    """
    Build and run an adb command synchronously, returning its output.

    :param device: Device serial/address, or None if not device-specific.
    :param command: Sub-command and arguments, e.g. ``'push f /sdcard/'``.
    :param timeout: Seconds to wait for completion, or None for no limit.
    :param adb_server: Optional custom ADB server host.
    :param adb_port: Optional custom ADB server port.
    :returns: The command's decoded output.
    :raises HostError: If the command fails or returns non-zero.
    """
    full_command = get_adb_command(device, command, adb_server, adb_port)
    logger.debug(full_command)
    output, _ = check_output(full_command, timeout, shell=True)
    return output


def adb_command_popen(device: Optional[str], conn: AdbConnection, command: str,
                      adb_server: Optional[str] = None,
                      adb_port: Optional[int] = None) -> 'Popen':
    """
    Start an adb command as a background host subprocess.

    :param device: Device serial/address, or None if not device-specific.
    :param conn: The :class:`AdbConnection` this command is associated with.
        Accepted for interface compatibility; the command line is built from
        the explicit parameters.
    :param command: Sub-command and arguments to run.
    :param adb_server: Optional custom ADB server host.
    :param adb_port: Optional custom ADB server port.
    :returns: The :class:`subprocess.Popen` object for the running command.
    """
    command = get_adb_command(device, command, adb_server, adb_port)
    logger.debug(command)
    popen = get_subprocess(command, shell=True)
    return popen
""" - dumpsys = target.execute('dumpsys package {}'.format(package)) + dumpsys: str = target.execute('dumpsys package {}'.format(package)) - permissions = re.search( + permissions: Optional[Match[str]] = re.search( r'requested permissions:\s*(?P(android.permission.+\s*)+)', dumpsys ) if permissions is None: return - permissions = permissions.group('permissions').replace(" ", "").splitlines() + permissions_list: List[str] = permissions.group('permissions').replace(" ", "").splitlines() - for permission in permissions: + for permission in permissions_list: try: target.execute('pm grant {} {}'.format(package, permission)) except TargetStableError: @@ -794,10 +1237,18 @@ class _AndroidEnvironment: # Make the initialization lazy so that we don't trigger an exception if the # user imports the module (directly or indirectly) without actually using # anything from it + """ + Lazy-initialized environment data for Android tools (adb, aapt, etc.), + constructed from ANDROID_HOME or by scanning the system PATH. + """ @property @functools.lru_cache(maxsize=None) - def env(self): - android_home = os.getenv('ANDROID_HOME') + def env(self) -> Android_Env_Type: + """ + :returns: The discovered Android environment mapping with keys like 'adb', 'aapt', etc. + :raises HostError: If we cannot find a suitable ANDROID_HOME or 'adb' in PATH. + """ + android_home: Optional[str] = os.getenv('ANDROID_HOME') if android_home: env = self._from_android_home(android_home) else: @@ -805,52 +1256,82 @@ def env(self): return env - def get_env(self, name): + def get_env(self, name: Android_Env_TypeKeys) -> Optional[Union[str, int]]: + """ + Retrieve a specific environment field, such as 'adb', 'aapt', or 'build_tools'. + + :param name: Name of the environment key. + :returns: The value if found, else None. 
+ """ return self.env[name] @classmethod - def _from_android_home(cls, android_home): + def _from_android_home(cls, android_home: str) -> Android_Env_Type: + """ + Build environment info from ANDROID_HOME. + + :param android_home: Path to Android SDK root. + :returns: Dictionary of environment settings. + """ logger.debug('Using ANDROID_HOME from the environment.') platform_tools = os.path.join(android_home, 'platform-tools') - return { + return cast(Android_Env_Type, { 'android_home': android_home, 'platform_tools': platform_tools, 'adb': os.path.join(platform_tools, 'adb'), 'fastboot': os.path.join(platform_tools, 'fastboot'), **cls._init_common(android_home) - } + }) @classmethod - def _from_adb(cls): + def _from_adb(cls) -> Android_Env_Type: + """ + Attempt to derive environment info by locating 'adb' on the system PATH. + + :returns: A dictionary of environment settings. + :raises HostError: If 'adb' is not found in PATH. + """ adb_path = which('adb') if adb_path: logger.debug('Discovering ANDROID_HOME from adb path.') platform_tools = os.path.dirname(adb_path) android_home = os.path.dirname(platform_tools) - return { + return cast(Android_Env_Type, { 'android_home': android_home, 'platform_tools': platform_tools, 'adb': adb_path, 'fastboot': which('fastboot'), **cls._init_common(android_home) - } + }) else: raise HostError('ANDROID_HOME is not set and adb is not in PATH. ' 'Have you installed Android SDK?') @classmethod - def _init_common(cls, android_home): + def _init_common(cls, android_home: str) -> BuildToolsInfo: + """ + Discover build tools, aapt, etc., from an Android SDK layout. + + :param android_home: Android SDK root path. + :returns: Partial dictionary with keys like 'build_tools', 'aapt', 'aapt_version'. 
+ """ logger.debug(f'ANDROID_HOME: {android_home}') build_tools = cls._discover_build_tools(android_home) - return { + return cast(BuildToolsInfo, { 'build_tools': build_tools, **cls._discover_aapt(build_tools) - } + }) @staticmethod - def _discover_build_tools(android_home): + def _discover_build_tools(android_home: str) -> Optional[str]: + """ + Attempt to locate the build-tools directory under android_home. + + :param android_home: Path to the SDK. + :returns: Path to build-tools if found, else None. + """ build_tools = os.path.join(android_home, 'build-tools') if os.path.isdir(build_tools): return build_tools @@ -858,7 +1339,13 @@ def _discover_build_tools(android_home): return None @staticmethod - def _check_supported_aapt2(binary): + def _check_supported_aapt2(binary: str) -> bool: + """ + Check if a given 'aapt2' binary supports 'dump badging'. + + :param binary: Path to the aapt2 binary. + :returns: True if the binary appears to support the 'badging' command, else False. + """ # At time of writing the version argument of aapt2 is not helpful as # the output is only a placeholder that does not distinguish between versions # with and without support for badging. Unfortunately aapt has been @@ -867,32 +1354,45 @@ def _check_supported_aapt2(binary): # Try to execute the badging command and check if we get an expected error # message as opposed to an unknown command error to determine if we have a # suitable version. 
- result = subprocess.run([str(binary), 'dump', 'badging'], stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, universal_newlines=True) + """ + check if aapt2 is supported + """ + result: 'CompletedProcess' = subprocess.run([str(binary), 'dump', 'badging'], + stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, + universal_newlines=True) supported = bool(AAPT_BADGING_OUTPUT.search(result.stderr)) - msg = 'Found a {} aapt2 binary at: {}' + msg: str = 'Found a {} aapt2 binary at: {}' logger.debug(msg.format('supported' if supported else 'unsupported', binary)) return supported @classmethod - def _discover_aapt(cls, build_tools): + def _discover_aapt(cls, build_tools: Optional[str]) -> Dict[str, Optional[Union[str, int]]]: + """ + Attempt to find 'aapt2' or 'aapt' in build-tools (or PATH fallback). + Prefers aapt2 if available. + + :param build_tools: Path to the build-tools directory or None if unknown. + :returns: A dictionary with 'aapt' and 'aapt_version' keys. + :raises HostError: If neither aapt nor aapt2 is found. 
+ """ if build_tools: - def find_aapt2(version): + def find_aapt2(version: str) -> Tuple[Optional[int], Optional[str]]: path = os.path.join(build_tools, version, 'aapt2') if os.path.isfile(path) and cls._check_supported_aapt2(path): return (2, path) else: return (None, None) - def find_aapt(version): - path = os.path.join(build_tools, version, 'aapt') + def find_aapt(version: str) -> Tuple[Optional[int], Optional[str]]: + path: str = os.path.join(build_tools, version, 'aapt') if os.path.isfile(path): return (1, path) else: return (None, None) - versions = os.listdir(build_tools) - found = ( + versions: List[str] = os.listdir(build_tools) + found: Generator[Tuple[str, Tuple[Optional[int], Optional[str]]]] = ( (version, finder(version)) for version in reversed(sorted(versions)) for finder in (find_aapt2, find_aapt) @@ -907,7 +1407,7 @@ def find_aapt(version): ) # Try detecting aapt2 and aapt from PATH - aapt2_path = which('aapt2') + aapt2_path: Optional[str] = which('aapt2') aapt_path = which('aapt') if aapt2_path and cls._check_supported_aapt2(aapt2_path): return dict( @@ -928,33 +1428,37 @@ class LogcatMonitor(object): Helper class for monitoring Anroid's logcat :param target: Android target to monitor - :type target: :class:`AndroidTarget` :param regexps: List of uncompiled regular expressions to filter on the device. Logcat entries that don't match any will not be seen. If omitted, all entries will be sent to host. - :type regexps: list(str) """ @property - def logfile(self): + def logfile(self) -> Optional[Union['TextIOWrapper', '_TemporaryFileWrapper[str]']]: + """ + Return the file-like object that logcat is writing to, if any. + + :returns: The log file or None. 
+ """ return self._logfile - def __init__(self, target, regexps=None, logcat_format=None): + def __init__(self, target: 'AndroidTarget', regexps: Optional[List[str]] = None, + logcat_format: Optional[str] = None): super(LogcatMonitor, self).__init__() self.target = target self._regexps = regexps self._logcat_format = logcat_format - self._logcat = None - self._logfile = None + self._logcat: Optional[spawn] = None + self._logfile: Optional[Union['TextIOWrapper', '_TemporaryFileWrapper[str]']] = None - def start(self, outfile=None): + def start(self, outfile: Optional[str] = None) -> None: """ - Start logcat and begin monitoring + Begin capturing logcat output. If outfile is given, logcat lines are + appended there; otherwise, a temporary file is used. - :param outfile: Optional path to file to store all logcat entries - :type outfile: str + :param outfile: A path to a file on the host, or None for a temporary file. """ if outfile: self._logfile = open(outfile, 'w') @@ -963,11 +1467,11 @@ def start(self, outfile=None): self.target.clear_logcat() - logcat_cmd = 'logcat' + logcat_cmd: str = 'logcat' # Join all requested regexps with an 'or' if self._regexps: - regexp = '{}'.format('|'.join(self._regexps)) + regexp: str = '{}'.format('|'.join(self._regexps)) if len(self._regexps) > 1: regexp = '({})'.format(regexp) # Logcat on older version of android do not support the -e argument @@ -980,26 +1484,41 @@ def start(self, outfile=None): if self._logcat_format: logcat_cmd = "{} -v {}".format(logcat_cmd, quote(self._logcat_format)) - logcat_cmd = get_adb_command(self.target.conn.device, logcat_cmd, self.target.adb_server, self.target.adb_port) - + logcat_cmd = get_adb_command(self.target.conn.device, + logcat_cmd, self.target.adb_server, + self.target.adb_port) if isinstance(self.target.conn, AdbConnection) else '' + logcat_cmd = f"/bin/bash -c '{logcat_cmd}'" logger.debug('logcat command ="{}"'.format(logcat_cmd)) self._logcat = pexpect.spawn(logcat_cmd, 
logfile=self._logfile, encoding='utf-8') - def stop(self): + def stop(self) -> None: + """ + Stop capturing logcat and close the log file if applicable. + """ self.flush_log() - self._logcat.terminate() - self._logfile.close() + if self._logcat: + self._logcat.terminate() + if self._logfile: + self._logfile.close() - def get_log(self): + def get_log(self) -> List[str]: """ - Return the list of lines found by the monitor + Retrieve all captured lines from the log so far. + + :returns: A list of log lines from the log file. """ self.flush_log() + if self._logfile: + with open(self._logfile.name) as fh: + return [line for line in fh] + else: + return [] - with open(self._logfile.name) as fh: - return [line for line in fh] - - def flush_log(self): + def flush_log(self) -> None: + """ + Force-read all pending data from the logcat pexpect spawn to ensure it's + written to the logfile. Prevents missed lines if pexpect hasn't pulled them yet. + """ # Unless we tell pexect to 'expect' something, it won't read from # logcat's buffer or write into our logfile. We'll need to force it to # read any pending logcat output. @@ -1009,7 +1528,9 @@ def flush_log(self): # This will read up to read_size bytes, but only those that are # already ready (i.e. it won't block). If there aren't any bytes # already available it raises pexpect.TIMEOUT. - buf = self._logcat.read_nonblocking(read_size, timeout=0) + buf: str = '' + if self._logcat: + buf = self._logcat.read_nonblocking(read_size, timeout=0) # We can't just keep calling read_nonblocking until we get a # pexpect.TIMEOUT (i.e. until we don't find any available @@ -1030,33 +1551,39 @@ def flush_log(self): # printed anything since pexpect last read from its buffer. break - def clear_log(self): - with open(self._logfile.name, 'w') as _: - pass + def clear_log(self) -> None: + """ + Erase current content of the log file so subsequent calls to get_log() + won't return older lines. 
+ """ + if self._logfile: + with open(self._logfile.name, 'w') as _: + pass - def search(self, regexp): + def search(self, regexp: str) -> List[str]: """ - Search a line that matches a regexp in the logcat log - Return immediatly + Search the captured lines for matches of the given regexp. + + :param regexp: A regular expression pattern. + :returns: All matching lines found so far. """ return [line for line in self.get_log() if re.match(regexp, line)] - def wait_for(self, regexp, timeout=30): + def wait_for(self, regexp: str, timeout: Optional[int] = 30) -> List[str]: """ Search a line that matches a regexp in the logcat log Wait for it to appear if it's not found :param regexp: regexp to search - :type regexp: str :param timeout: Timeout in seconds, before rasing RuntimeError. ``None`` means wait indefinitely - :type timeout: number :returns: List of matched strings + :raises RuntimeError: If the regex is not found within ``timeout`` seconds. """ - log = self.get_log() - res = [line for line in log if re.match(regexp, line)] + log: List[str] = self.get_log() + res: List[str] = [line for line in log if re.match(regexp, line)] # Found some matches, return them if res: @@ -1064,15 +1591,16 @@ def wait_for(self, regexp, timeout=30): # Store the number of lines we've searched already, so we don't have to # re-grep them after 'expect' returns - next_line_num = len(log) + next_line_num: int = len(log) try: - self._logcat.expect(regexp, timeout=timeout) + if self._logcat: + self._logcat.expect(regexp, timeout=timeout) except pexpect.TIMEOUT: raise RuntimeError('Logcat monitor timeout ({}s)'.format(timeout)) return [line for line in self.get_log()[next_line_num:] if re.match(regexp, line)] -_ANDROID_ENV = _AndroidEnvironment() +_ANDROID_ENV = _AndroidEnvironment() diff --git a/devlib/utils/annotation_helpers.py b/devlib/utils/annotation_helpers.py new file mode 100644 index 000000000..cee651a8c --- /dev/null +++ b/devlib/utils/annotation_helpers.py @@ -0,0 +1,72 @@ +# 
# Copyright 2025 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Type aliases and TypedDicts used to annotate devlib code.
"""
import os
import sys

from typing import Union, Sequence, Optional
from typing_extensions import NotRequired, LiteralString, TYPE_CHECKING, TypedDict

if TYPE_CHECKING:
    from _typeshed import StrPath, BytesPath
    from devlib.platform import Platform
    from devlib.utils.android import AdbConnection
    from devlib.utils.ssh import SshConnection
    from devlib.host import LocalConnection
    from devlib.connection import PopenBackgroundCommand, AdbBackgroundCommand, ParamikoBackgroundCommand
else:
    # Runtime stand-ins for the typeshed-only path types.
    StrPath = str
    BytesPath = bytes


# Anything accepted as the command argument of subprocess APIs.
# os.PathLike only became subscriptable in Python 3.9, hence the split.
if sys.version_info >= (3, 9):
    SubprocessCommand = Union[
        str, bytes, os.PathLike[str], os.PathLike[bytes],
        Sequence[Union[str, bytes, os.PathLike[str], os.PathLike[bytes]]]]
else:
    SubprocessCommand = Union[str, bytes, os.PathLike,
                              Sequence[Union[str, bytes, os.PathLike]]]

# Any of the concrete background-command implementations.
BackgroundCommand = Union['AdbBackgroundCommand', 'ParamikoBackgroundCommand', 'PopenBackgroundCommand']

# The connection types devlib knows how to drive.
SupportedConnections = Union['LocalConnection', 'AdbConnection', 'SshConnection']


class SshUserConnectionSettings(TypedDict, total=False):
    """
    User-supplied keyword settings for an SSH connection. All keys are
    optional.
    """
    username: NotRequired[str]
    password: NotRequired[str]
    keyfile: NotRequired[Optional[Union[LiteralString, StrPath, BytesPath]]]
    host: NotRequired[str]
    port: NotRequired[int]
    timeout: NotRequired[float]
    platform: NotRequired['Platform']
    sudo_cmd: NotRequired[str]
    strict_host_check: NotRequired[bool]
    use_scp: NotRequired[bool]
    poll_transfers: NotRequired[bool]
    start_transfer_poll_delay: NotRequired[int]
    total_transfer_timeout: NotRequired[int]
    transfer_poll_period: NotRequired[int]


class AdbUserConnectionSettings(SshUserConnectionSettings):
    """
    User-supplied settings for an ADB connection, extending the SSH ones.

    NOTE: this subclass defaults to ``total=True``, so the explicit
    ``NotRequired`` markers here are what keep these keys optional.
    """
    device: NotRequired[str]
    adb_server: NotRequired[str]
    adb_port: NotRequired[int]


UserConnectionSettings = Union[SshUserConnectionSettings, AdbUserConnectionSettings]
+ """ if isinstance(awaitable, asyncio.Task): - task = awaitable + task: Task = awaitable else: - task = asyncio.create_task(awaitable) + task = asyncio.create_task(cast(Coroutine, awaitable)) if name is None: name = getattr(awaitable, '__qualname__', None) - task.name = name + task.set_name(name) return task -def _close_loop(loop): +def _close_loop(loop: Optional[AbstractEventLoop]) -> None: + """ + Close an asyncio event loop after shutting down asynchronous generators and the default executor. + + :param loop: The event loop to close, or None. + """ if loop is not None: try: loop.run_until_complete(loop.shutdown_asyncgens()) @@ -62,55 +84,71 @@ def _close_loop(loop): class AsyncManager: - def __init__(self): - self.task_tree = dict() - self.resources = dict() + """ + Manages asynchronous operations by tracking tasks and ensuring that concurrently + running asynchronous functions do not interfere with one another. - def track_access(self, access): + This manager maintains a mapping of tasks to resources and allows running tasks + concurrently while checking for overlapping resource usage. + """ + def __init__(self) -> None: + """ + Initialize the AsyncManager with empty task trees and resource maps. + """ + self.task_tree: Dict[Task, Set[Task]] = dict() + self.resources: Dict[Task, Set['ConcurrentAccessBase']] = dict() + + def track_access(self, access: 'ConcurrentAccessBase') -> None: """ Register the given ``access`` to have been handled by the current async task. :param access: Access that were done. - :type access: ConcurrentAccessBase This allows :func:`concurrently` to check that concurrent tasks did not step on each other's toes. 
""" try: - task = asyncio.current_task() + task: Optional[Task] = asyncio.current_task() except RuntimeError: pass else: - self.resources.setdefault(task, set()).add(access) + if task: + self.resources.setdefault(task, set()).add(access) - async def concurrently(self, awaitables): + async def concurrently(self, awaitables: Iterable[Awaitable]) -> List[Any]: """ Await concurrently for the given awaitables, and cancel them as soon as one raises an exception. + + :param awaitables: An iterable of coroutine objects to run concurrently. + :returns: A list with the results of the awaitables. + :raises Exception: Propagates the first exception encountered, canceling the others. """ - awaitables = list(awaitables) + awaitables_list: List[Awaitable] = list(awaitables) # Avoid creating asyncio.Tasks when it's not necessary, as it will # disable a the blocking path optimization of Target._execute_async() # that uses blocking calls as long as there is only one asyncio.Task # running on the event loop. 
- if len(awaitables) == 1: - return [await awaitables[0]] + if len(awaitables_list) == 1: + return [await awaitables_list[0]] - tasks = list(map(create_task, awaitables)) + tasks: List[Task] = list(map(create_task, awaitables_list)) - current_task = asyncio.current_task() - task_tree = self.task_tree + current_task: Optional[Task] = asyncio.current_task() + task_tree: Dict[Task, Set[Task]] = self.task_tree try: - node = task_tree[current_task] + if current_task: + node: Set[Task] = task_tree[current_task] except KeyError: - is_root_task = True + is_root_task: bool = True node = set() else: is_root_task = False - task_tree[current_task] = node + if current_task: + task_tree[current_task] = node task_tree.update({ child: set() @@ -126,8 +164,12 @@ async def concurrently(self, awaitables): raise finally: - def get_children(task): - immediate_children = task_tree[task] + def get_children(task: Task) -> frozenset[Task]: + """ + get the children of the task and their children etc and return as a + single set + """ + immediate_children: Set[Task] = task_tree[task] return frozenset( itertools.chain( [task], @@ -140,7 +182,7 @@ def get_children(task): # Get the resources created during the execution of each subtask # (directly or indirectly) - resources = { + resources: Dict[Task, frozenset['ConcurrentAccessBase']] = { task: frozenset( itertools.chain.from_iterable( self.resources.get(child, []) @@ -153,18 +195,20 @@ def get_children(task): for res1, res2 in itertools.product(resources1, resources2): if issubclass(res2.__class__, res1.__class__) and res1.overlap_with(res2): raise RuntimeError( - 'Overlapping resources manipulated in concurrent async tasks: {} (task {}) and {} (task {})'.format(res1, task1.name, res2, task2.name) + 'Overlapping resources manipulated in concurrent async tasks: {} (task {}) and {} (task {})'.format(res1, task1.get_name(), res2, task2.get_name()) ) if is_root_task: self.resources.clear() task_tree.clear() - async def map_concurrently(self, f, 
keys): + async def map_concurrently(self, f: Callable, keys: Any) -> Dict: """ Similar to :meth:`concurrently`, but maps the given function ``f`` on the given ``keys``. + :param f: The function to apply to each key. + :param keys: An iterable of keys. :return: A dictionary with ``keys`` as keys, and function result as values. """ @@ -175,13 +219,16 @@ async def map_concurrently(self, f, keys): )) -def compose(*coros): +def compose(*coros: Callable) -> Callable[..., Coroutine]: """ Compose coroutines, feeding the output of each as the input of the next one. ``await compose(f, g)(x)`` is equivalent to ``await f(await g(x))`` + :param coros: A variable number of coroutine functions. + :returns: A callable that, when awaited, composes the coroutines in sequence. + .. note:: In Haskell, ``compose f g h`` would be equivalent to ``f <=< g <=< h`` """ async def f(*args, **kwargs): @@ -205,8 +252,11 @@ class _AsyncPolymorphicFunction: When called, the blocking synchronous operation is called. The ```asyn`` attribute gives access to the asynchronous version of the function, and all the other attribute access will be redirected to the async function. + + :param asyn: The asynchronous version of the function. + :param blocking: The synchronous (blocking) version of the function. """ - def __init__(self, asyn, blocking): + def __init__(self, asyn: Callable[..., Awaitable], blocking: Callable[..., Any]): self.asyn = asyn self.blocking = blocking functools.update_wrapper(self, asyn) @@ -240,36 +290,45 @@ class memoized_method: * non-async methods * method already decorated with :func:`devlib.asyn.asyncf`. + :param f: The method to memoize. + .. note:: This decorator does not rely on hacks to hash unhashable data. If such input is required, it will either have to be coerced to a hashable first (e.g. converting a list to a tuple), or the code of :func:`devlib.asyn.memoized_method` will have to be updated to do so. 
""" - def __init__(self, f): - memo = self - - sig = inspect.signature(f) - - def bind(self, *args, **kwargs): - bound = sig.bind(self, *args, **kwargs) + def __init__(self, f: Callable): + memo: 'memoized_method' = self + + sig: Signature = inspect.signature(f) + + def bind(self, *args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], + Tuple[Any, ...], + Dict[str, Any]]: + """ + bind arguments to function signature + """ + bound: BoundArguments = sig.bind(self, *args, **kwargs) bound.apply_defaults() key = (bound.args[1:], tuple(sorted(bound.kwargs.items()))) return (key, bound.args, bound.kwargs) - def get_cache(self): + def get_cache(self) -> Dict[Tuple[Any, ...], Any]: try: - cache = self.__dict__[memo.name] + cache: Dict[Tuple[Any, ...], Any] = self.__dict__[memo.name] except KeyError: cache = {} self.__dict__[memo.name] = cache return cache - if inspect.iscoroutinefunction(f): @functools.wraps(f) - async def wrapper(self, *args, **kwargs): - cache = get_cache(self) + async def async_wrapper(self, *args: Any, **kwargs: Any) -> Any: + """ + wrapper for async functions + """ + cache: Dict[Tuple[Any, ...], Any] = get_cache(self) key, args, kwargs = bind(self, *args, **kwargs) try: return cache[key] @@ -277,9 +336,13 @@ async def wrapper(self, *args, **kwargs): x = await f(*args, **kwargs) cache[key] = x return x + self.f: Callable[..., Coroutine] = async_wrapper else: @functools.wraps(f) - def wrapper(self, *args, **kwargs): + def sync_wrapper(self, *args: Any, **kwargs: Any) -> Any: + """ + wrapper for sync functions + """ cache = get_cache(self) key, args, kwargs = bind(self, *args, **kwargs) try: @@ -288,25 +351,24 @@ def wrapper(self, *args, **kwargs): x = f(*args, **kwargs) cache[key] = x return x + self.f = sync_wrapper - - self.f = wrapper self._name = f.__name__ @property - def name(self): + def name(self) -> str: return '__memoization_cache_of_' + self._name def __call__(self, *args, **kwargs): return self.f(*args, **kwargs) - def __get__(self, obj, 
owner=None): + def __get__(self, obj: Optional['memoized_method'], owner: Optional[Type['memoized_method']] = None) -> Any: return self.f.__get__(obj, owner) - def __set__(self, obj, value): + def __set__(self, obj: 'memoized_method', value: Any): raise RuntimeError("Cannot monkey-patch a memoized function") - def __set_name__(self, owner, name): + def __set_name__(self, owner: Type['memoized_method'], name: str): self._name = name @@ -325,22 +387,31 @@ def __init__(self, *args, **kwargs): self.gr_context = contextvars.copy_context() @classmethod - def from_coro(cls, coro): + def from_coro(cls, coro: Coroutine) -> '_Genlet': """ Create a :class:`_Genlet` from a given coroutine, treating it as a generator. + + :param coro: The coroutine to wrap. + :returns: A _Genlet that wraps the coroutine. """ - f = lambda value: self.consume_coro(coro, value) + def f(value: Any) -> Any: + return self.consume_coro(coro, value) self = cls(f) return self - def consume_coro(self, coro, value): + def consume_coro(self, coro: Coroutine, value: Any) -> Any: """ Send ``value`` to ``coro`` then consume the coroutine, passing all its yielded actions to the enclosing :class:`_Genlet`. This allows crossing blocking calls layers as if they were async calls with `await`. + + :param coro: The coroutine to consume. + :param value: The initial value to send. + :returns: The final value returned by the coroutine. + :raises StopIteration: When the coroutine is exhausted. 
""" - excep = None + excep: Optional[BaseException] = None while True: try: if excep is None: @@ -351,11 +422,11 @@ def consume_coro(self, coro, value): except StopIteration as e: return e.value else: - parent = self.parent + parent: Optional[greenlet] = self.parent # Switch back to the consumer that returns the values via # send() try: - value = parent.switch(future) + value = parent.switch(future) if parent else None except BaseException as e: excep = e value = None @@ -363,17 +434,27 @@ def consume_coro(self, coro, value): excep = None @classmethod - def get_enclosing(cls): + def get_enclosing(cls) -> Optional['_Genlet']: """ Get the immediately enclosing :class:`_Genlet` in the callstack or ``None``. + + :returns: The nearest _Genlet instance in the chain, or None if not found. """ g = greenlet.getcurrent() while not (isinstance(g, cls) or g is None): g = g.parent return g - def _send_throw(self, value, excep): + def _send_throw(self, value: Optional['_Genlet'], excep: Optional[BaseException]) -> Any: + """ + helper function to do switch to another genlet or throw exception + + :param value: The value to send to the parent. + :param excep: The exception to throw in the parent, or None. + :returns: The result returned from the parent's switch. + :raises StopIteration: If the parent completes. + """ self.parent = greenlet.getcurrent() # Switch back to the function yielding values @@ -387,55 +468,78 @@ def _send_throw(self, value, excep): else: raise StopIteration(result) - def gen_send(self, x): + def gen_send(self, x: Optional['_Genlet']) -> Any: """ Similar to generators' ``send`` method. + + :param x: The value to send. + :returns: The value received from the parent. """ return self._send_throw(x, None) - def gen_throw(self, x): + def gen_throw(self, x: Optional[BaseException]): """ Similar to generators' ``throw`` method. + + :param x: The exception to throw. + :returns: The value received from the parent after handling the exception. 
""" return self._send_throw(None, x) class _AwaitableGenlet: """ - Wrap a coroutine with a :class:`_Genlet` and wrap that to be awaitable. + Wraps a coroutine with a :class:`_Genlet` to allow it to be awaited using + the normal 'await' syntax. + + :param coro: The coroutine to wrap. """ @classmethod - def wrap_coro(cls, coro): - async def coro_f(): + def wrap_coro(cls, coro: Coroutine) -> Coroutine: + """ + Wrap a coroutine inside an _AwaitableGenlet so that it becomes awaitable. + + :param coro: The coroutine to wrap. + :returns: An awaitable version of the coroutine. + """ + async def coro_f() -> Any: # Make sure every new task will be instrumented since a task cannot # yield futures on behalf of another task. If that were to happen, # the task B trying to do a nested yield would switch back to task # A, asking to yield on its behalf. Since the event loop would be # currently handling task B, nothing would handle task A trying to # yield on behalf of B, leading to a deadlock. - loop = asyncio.get_running_loop() + loop: AbstractEventLoop = asyncio.get_running_loop() _install_task_factory(loop) # Create a top-level _AwaitableGenlet that all nested runs will use # to yield their futures - _coro = cls(coro) + _coro: '_AwaitableGenlet' = cls(coro) return await _coro return coro_f() - def __init__(self, coro): + def __init__(self, coro: Coroutine): self._coro = coro - def __await__(self): - coro = self._coro - is_started = inspect.iscoroutine(coro) and coro.cr_running + def __await__(self) -> Generator: + """ + Make the _AwaitableGenlet awaitable. + + :returns: A generator that yields from the wrapped coroutine. 
+ """ + coro: Coroutine = self._coro + is_started: bool = inspect.iscoroutine(coro) and coro.cr_running - def genf(): + def genf() -> Generator: + """ + generator function + """ gen = _Genlet.from_coro(coro) - value = None - excep = None + value: Optional[_Genlet] = None + excep: Optional[BaseException] = None # The coroutine is already started, so we need to dispatch the # value from the upcoming send() to the gen without running @@ -468,25 +572,35 @@ def genf(): gen = genf() if is_started: # Start the generator so it waits at the first yield point - gen.gen_send(None) + cast(_Genlet, gen).gen_send(None) return gen -def _allow_nested_run(coro): +def _allow_nested_run(coro: Coroutine) -> Coroutine: + """ + If the current callstack does not have an enclosing _Genlet, wrap the coroutine + using _AwaitableGenlet; otherwise, return the coroutine unchanged. + + :param coro: The coroutine to potentially wrap. + :returns: The original coroutine or a wrapped awaitable coroutine. + """ if _Genlet.get_enclosing() is None: return _AwaitableGenlet.wrap_coro(coro) else: return coro -def allow_nested_run(coro): +def allow_nested_run(coro: Coroutine) -> Coroutine: """ Wrap the coroutine ``coro`` such that nested calls to :func:`run` will be - allowed. + allowed. This is useful when a coroutine needs to yield control to another layer. .. warning:: The coroutine needs to be consumed in the same OS thread it was created in. + + :param coro: The coroutine to wrap. + :returns: A possibly wrapped coroutine that allows nested execution. """ return _allow_nested_run(coro) @@ -503,7 +617,13 @@ def allow_nested_run(coro): ) -def _check_executor_alive(executor): +def _check_executor_alive(executor: ThreadPoolExecutor) -> bool: + """ + Check if the given ThreadPoolExecutor is still alive by submitting a no-op job. + + :param executor: The ThreadPoolExecutor to check. + :returns: True if the executor accepts new jobs; False otherwise. 
+ """ try: executor.submit(lambda: None) except RuntimeError: @@ -513,29 +633,37 @@ def _check_executor_alive(executor): _PATCHED_LOOP_LOCK = threading.Lock() -_PATCHED_LOOP = WeakSet() -def _install_task_factory(loop): +_PATCHED_LOOP: WeakSet = WeakSet() + + +def _install_task_factory(loop: AbstractEventLoop): """ Install a task factory on the given event ``loop`` so that top-level coroutines are wrapped using :func:`allow_nested_run`. This ensures that the nested :func:`run` infrastructure will be available. + + :param loop: The asyncio event loop on which to install the task factory. """ - def install(loop): + def install(loop: AbstractEventLoop) -> None: + """ + install the task factory on the event loop + """ if sys.version_info >= (3, 11): - def default_factory(loop, coro, context=None): + def default_factory(loop: AbstractEventLoop, coro: Coroutine, context: Optional[Context] = None) -> Optional[Task]: return asyncio.Task(coro, loop=loop, context=context) else: - def default_factory(loop, coro, context=None): + def default_factory(loop: AbstractEventLoop, coro: Coroutine, context: Optional[Context] = None) -> Optional[Task]: return asyncio.Task(coro, loop=loop) make_task = loop.get_task_factory() or default_factory - def factory(loop, coro, context=None): + + def factory(loop: AbstractEventLoop, coro: Coroutine, context: Optional[Context] = None) -> Optional[Task]: # Make sure each Task will be able to yield on behalf of its nested # await beneath blocking layers coro = _AwaitableGenlet.wrap_coro(coro) - return make_task(loop, coro, context=context) + return cast(Callable, make_task)(loop, coro, context=context) - loop.set_task_factory(factory) + loop.set_task_factory(cast(Callable, factory)) with _PATCHED_LOOP_LOCK: if loop in _PATCHED_LOOP: @@ -545,13 +673,16 @@ def factory(loop, coro, context=None): _PATCHED_LOOP.add(loop) -def _set_current_context(ctx): +def _set_current_context(ctx: Optional[Context]) -> None: """ Get all the variable from the passed 
``ctx`` and set them in the current context. + + :param ctx: A Context object containing variable values to set. """ - for var, val in ctx.items(): - var.set(val) + if ctx: + for var, val in ctx.items(): + var.set(val) class _CoroRunner(abc.ABC): @@ -564,10 +695,22 @@ class _CoroRunner(abc.ABC): single event loop. """ @abc.abstractmethod - def _run(self, coro): + def _run(self, coro: Coroutine) -> Any: + """ + Execute the given coroutine using the runner's mechanism. + + :param coro: The coroutine to run. + """ pass - def run(self, coro): + def run(self, coro: Coroutine) -> Any: + """ + Run the provided coroutine using the implemented runner. Raises an + assertion error if the coroutine is already running. + + :param coro: The coroutine to run. + :returns: The result of the coroutine. + """ # Ensure we have a fresh coroutine. inspect.getcoroutinestate() does not # work on all objects that asyncio creates on some version of Python, such # as iterable_coroutine @@ -588,26 +731,38 @@ class _ThreadCoroRunner(_CoroRunner): Critically, this allows running multiple coroutines out of the same thread, which will be reserved until the runner ``__exit__`` method is called. + + :param future: A Future representing the thread running the coroutine loop. + :param jobq: A SimpleQueue for scheduling coroutine jobs. + :param resq: A SimpleQueue to collect results from executed coroutines. """ - def __init__(self, future, jobq, resq): + def __init__(self, future: 'Future', jobq: 'SimpleQueue[Optional[Tuple[Context, Coroutine]]]', + resq: 'SimpleQueue[Tuple[Context, Optional[BaseException], Any]]'): self._future = future self._jobq = jobq self._resq = resq @staticmethod - def _thread_f(jobq, resq): - def handle_jobs(runner): + def _thread_f(jobq: 'SimpleQueue[Optional[Tuple[Context, Coroutine]]]', + resq: 'SimpleQueue[Tuple[Context, Optional[BaseException], Any]]') -> None: + """ + Thread function that continuously processes scheduled coroutine jobs. + + :param jobq: Queue of jobs. 
+ :param resq: Queue to store results from the jobs. + """ + def handle_jobs(runner: _LoopCoroRunner) -> None: while True: - job = jobq.get() + job: Optional[Tuple[Context, Coroutine]] = jobq.get() if job is None: return else: ctx, coro = job try: - value = ctx.run(runner.run, coro) + value: Any = ctx.run(runner.run, coro) except BaseException as e: value = None - excep = e + excep: Optional[BaseException] = e else: excep = None @@ -617,12 +772,19 @@ def handle_jobs(runner): handle_jobs(runner) @classmethod - def from_executor(cls, executor): - jobq = queue.SimpleQueue() - resq = queue.SimpleQueue() + def from_executor(cls, executor: ThreadPoolExecutor) -> '_ThreadCoroRunner': + """ + Create a _ThreadCoroRunner by submitting the thread function to an executor. + + :param executor: A ThreadPoolExecutor to run the coroutine loop. + :returns: An instance of _ThreadCoroRunner. + :raises RuntimeError: If the executor is not alive. + """ + jobq: SimpleQueue[Optional[Tuple[Context, Coroutine]]] = queue.SimpleQueue() + resq: SimpleQueue = queue.SimpleQueue() try: - future = executor.submit(cls._thread_f, jobq, resq) + future: Future = executor.submit(cls._thread_f, jobq, resq) except RuntimeError as e: if _check_executor_alive(executor): raise e @@ -635,7 +797,14 @@ def from_executor(cls, executor): future=future, ) - def _run(self, coro): + def _run(self, coro: Coroutine) -> Any: + """ + Schedule and run a coroutine in the separate thread, waiting for its result. + + :param coro: The coroutine to execute. + :returns: The result from running the coroutine. + :raises Exception: Propagates any exception raised by the coroutine. + """ ctx = contextvars.copy_context() self._jobq.put((ctx, coro)) ctx, excep, value = self._resq.get() @@ -659,20 +828,29 @@ class _LoopCoroRunner(_CoroRunner): The passed event loop is assumed to not be running. If ``None`` is passed, a new event loop will be created in ``__enter__`` and closed in ``__exit__``. 
+ + :param loop: An event loop to use; if None, a new one is created. """ - def __init__(self, loop): + def __init__(self, loop: Optional[AbstractEventLoop]): self.loop = loop - self._owned = False + self._owned: bool = False - def _run(self, coro): + def _run(self, coro: Coroutine) -> Any: + """ + Run the given coroutine to completion on the event loop and return its result. + + :param coro: The coroutine to run. + :returns: The result of the coroutine. + """ loop = self.loop # Back-propagate the contextvars that could have been modified by the # coroutine. This could be handled by asyncio.Runner().run(..., # context=...) or loop.create_task(..., context=...) but these APIs are # only available since Python 3.11 - ctx = None - async def capture_ctx(): + ctx: Optional[Context] = None + + async def capture_ctx() -> Any: nonlocal ctx try: return await _allow_nested_run(coro) @@ -680,12 +858,13 @@ async def capture_ctx(): ctx = contextvars.copy_context() try: - return loop.run_until_complete(capture_ctx()) + if loop: + return loop.run_until_complete(capture_ctx()) finally: _set_current_context(ctx) - def __enter__(self): - loop = self.loop + def __enter__(self) -> '_LoopCoroRunner': + loop: Optional[AbstractEventLoop] = self.loop if loop is None: owned = True loop = asyncio.new_event_loop() @@ -708,16 +887,33 @@ class _GenletCoroRunner(_CoroRunner): """ Run a coroutine assuming one of the parent coroutines was wrapped with :func:`allow_nested_run`. + + :param g: The enclosing _Genlet instance. """ - def __init__(self, g): + def __init__(self, g: _Genlet): self._g = g - def _run(self, coro): + def _run(self, coro: Coroutine) -> Any: + """ + Execute the coroutine by delegating to the enclosing _Genlet's consume_coro method. + + :param coro: The coroutine to run. + :returns: The result of the coroutine. 
+ """ return self._g.consume_coro(coro, None) -def _get_runner(): - executor = _CORO_THREAD_EXECUTOR +def _get_runner() -> Union[_GenletCoroRunner, + _LoopCoroRunner, + _ThreadCoroRunner]: + """ + Determine the appropriate coroutine runner based on the current context. + Returns a _GenletCoroRunner if an enclosing _Genlet is present, a _LoopCoroRunner + if an event loop exists (or can be created), or a _ThreadCoroRunner if an event loop is running. + + :returns: A coroutine runner appropriate for the current execution context. + """ + executor: ThreadPoolExecutor = _CORO_THREAD_EXECUTOR g = _Genlet.get_enclosing() try: loop = asyncio.get_running_loop() @@ -748,7 +944,7 @@ def _get_runner(): return _ThreadCoroRunner.from_executor(executor) -def run(coro): +def run(coro: Coroutine) -> Any: """ Similar to :func:`asyncio.run` but can be called while an event loop is running if a coroutine higher in the callstack has been wrapped using @@ -759,13 +955,16 @@ def run(coro): be reflected in the context of the caller. This allows context variable updates to cross an arbitrary number of run layers, as if all those layers were just part of the same coroutine. + + :param coro: The coroutine to execute. + :returns: The result of the coroutine. """ runner = _get_runner() with runner as runner: return runner.run(coro) -def asyncf(f): +def asyncf(f: Callable): """ Decorator used to turn a coroutine into a blocking function, with an optional asynchronous API. @@ -787,17 +986,20 @@ async def foo(x): This allows the same implementation to be both used as blocking for ease of use and backward compatibility, or exposed as a corountine for callers that can deal with awaitables. + + :param f: The asynchronous function to decorate. + :returns: A callable that runs f synchronously, with an asynchronous version available as .asyn. 
""" @functools.wraps(f) - def blocking(*args, **kwargs): + def blocking(*args, **kwargs) -> Any: # Since run() needs a corountine, make sure we provide one - async def wrapper(): + async def wrapper() -> Generator: x = f(*args, **kwargs) # Async generators have to be consumed and accumulated in a list # before crossing a blocking boundary. if inspect.isasyncgen(x): - def genf(): + def genf() -> Generator: asyncgen = x.__aiter__() while True: try: @@ -817,18 +1019,22 @@ def genf(): class _AsyncPolymorphicCMState: - def __init__(self): - self.nesting = 0 - self.runner = None + def __init__(self) -> None: + self.nesting: int = 0 + self.runner: Optional[Union[_GenletCoroRunner, + _LoopCoroRunner, + _ThreadCoroRunner]] = None - def _update_nesting(self, n): + def _update_nesting(self, n: int) -> bool: x = self.nesting assert x >= 0 x = x + n self.nesting = x return bool(x) - def _get_runner(self): + def _get_runner(self) -> Optional[Union[_GenletCoroRunner, + _LoopCoroRunner, + _ThreadCoroRunner]]: runner = self.runner if runner is None: assert not self.nesting @@ -837,8 +1043,8 @@ def _get_runner(self): self.runner = runner return runner - def _cleanup_runner(self, force=False): - def cleanup(): + def _cleanup_runner(self, force: bool = False) -> None: + def cleanup() -> None: self.runner = None if runner is not None: runner.__exit__(None, None, None) @@ -856,13 +1062,21 @@ class _AsyncPolymorphicCM: """ Wrap an async context manager such that it exposes a synchronous API as well for backward compatibility. + + :param async_cm: The asynchronous context manager to wrap. """ - def __init__(self, async_cm): + def __init__(self, async_cm: AsyncContextManager): self.cm = async_cm - self._state = threading.local() + self._state: local = threading.local() def _get_state(self): + """ + Retrieve or initialize the thread-local state for this context manager. + + :returns: The state object. 
+ :rtype: _AsyncPolymorphicCMState + """ try: return self._state.x except AttributeError: @@ -870,7 +1084,10 @@ def _get_state(self): self._state.x = state return state - def _delete_state(self): + def _delete_state(self) -> None: + """ + Delete the thread-local state. + """ try: del self._state.x except AttributeError: @@ -883,33 +1100,39 @@ def __aexit__(self, *args, **kwargs): return self.cm.__aexit__(*args, **kwargs) @staticmethod - def _exit(state): + def _exit(state: _AsyncPolymorphicCMState) -> None: state._update_nesting(-1) state._cleanup_runner() - def __enter__(self, *args, **kwargs): - state = self._get_state() - runner = state._get_runner() + def __enter__(self, *args, **kwargs) -> Any: + state: _AsyncPolymorphicCMState = self._get_state() + runner: Optional[Union[_GenletCoroRunner, + _LoopCoroRunner, + _ThreadCoroRunner]] = state._get_runner() # Increase the nesting count _before_ we start running the # coroutine, in case it is a recursive context manager state._update_nesting(1) try: - coro = self.cm.__aenter__(*args, **kwargs) - return runner.run(coro) + coro: Coroutine = self.cm.__aenter__(*args, **kwargs) + if runner: + return runner.run(coro) except BaseException: self._exit(state) raise - def __exit__(self, *args, **kwargs): - coro = self.cm.__aexit__(*args, **kwargs) + def __exit__(self, *args, **kwargs) -> Any: + coro: Coroutine = self.cm.__aexit__(*args, **kwargs) - state = self._get_state() - runner = state._get_runner() + state: _AsyncPolymorphicCMState = self._get_state() + runner: Optional[Union[_GenletCoroRunner, + _LoopCoroRunner, + _ThreadCoroRunner]] = state._get_runner() try: - return runner.run(coro) + if runner: + return runner.run(coro) finally: self._exit(state) @@ -917,16 +1140,22 @@ def __del__(self): self._get_state()._cleanup_runner(force=True) -def asynccontextmanager(f): +T = TypeVar('T') + + +def asynccontextmanager(f: Callable[..., AsyncGenerator[T, None]]) -> Callable[..., _AsyncPolymorphicCM]: """ Same as 
:func:`contextlib.asynccontextmanager` except that it can also be used with a regular ``with`` statement for backward compatibility. + + :param f: A callable that returns an asynchronous generator. + :returns: A context manager supporting both synchronous and asynchronous usage. """ - f = contextlib.asynccontextmanager(f) + f_int = contextlib.asynccontextmanager(f) - @functools.wraps(f) - def wrapper(*args, **kwargs): - cm = f(*args, **kwargs) + @functools.wraps(f_int) + def wrapper(*args: Any, **kwargs: Any) -> _AsyncPolymorphicCM: + cm = f_int(*args, **kwargs) return _AsyncPolymorphicCM(cm) return wrapper @@ -935,46 +1164,53 @@ def wrapper(*args, **kwargs): class ConcurrentAccessBase(abc.ABC): """ Abstract Base Class for resources tracked by :func:`concurrently`. + Subclasses must implement the method to determine if two resources overlap. """ @abc.abstractmethod - def overlap_with(self, other): + def overlap_with(self, other: 'ConcurrentAccessBase') -> bool: """ Return ``True`` if the resource overlaps with the given one. :param other: Resources that should not overlap with ``self``. - :type other: devlib.utils.asym.ConcurrentAccessBase + :returns: True if the two resources overlap; False otherwise. .. note:: It is guaranteed that ``other`` will be a subclass of our class. """ + class PathAccess(ConcurrentAccessBase): """ Concurrent resource representing a file access. :param namespace: Identifier of the namespace of the path. One of "target" or "host". - :type namespace: str :param path: Normalized path to the file. - :type path: str :param mode: Opening mode of the file. Can be ``"r"`` for read and ``"w"`` for writing. 
- :type mode: str """ - def __init__(self, namespace, path, mode): + def __init__(self, namespace: str, path: str, mode: str): assert namespace in ('host', 'target') self.namespace = namespace assert mode in ('r', 'w') self.mode = mode self.path = os.path.abspath(path) if namespace == 'host' else os.path.normpath(path) - def overlap_with(self, other): + def overlap_with(self, other: ConcurrentAccessBase) -> bool: + """ + Check if this path access overlaps with another access, considering + namespace, mode, and filesystem hierarchy. + + :param other: Another resource access instance. + :returns: True if the two paths overlap (and one of the accesses is for writing), else False. + """ + other_internal = cast('PathAccess', other) path1 = pathlib.Path(self.path).resolve() - path2 = pathlib.Path(other.path).resolve() + path2 = pathlib.Path(other_internal.path).resolve() return ( - self.namespace == other.namespace and - 'w' in (self.mode, other.mode) and + self.namespace == other_internal.namespace and + 'w' in (self.mode, other_internal.mode) and ( path1 == path2 or path1 in path2.parents or @@ -983,6 +1219,11 @@ def overlap_with(self, other): ) def __str__(self): + """ + Return a string representation of the PathAccess, including the path and mode. + + :returns: A string describing the path access. + """ mode = { 'r': 'read', 'w': 'write', diff --git a/devlib/utils/gem5.py b/devlib/utils/gem5.py index cc48c0723..78e6616af 100644 --- a/devlib/utils/gem5.py +++ b/devlib/utils/gem5.py @@ -13,9 +13,9 @@ # limitations under the License. 
import re -import logging from devlib.utils.types import numeric +from devlib.utils.misc import get_logger GEM5STATS_FIELD_REGEX = re.compile(r"^(?P[^- ]\S*) +(?P[^#]+).+$") @@ -23,7 +23,7 @@ GEM5STATS_DUMP_TAIL = '---------- End Simulation Statistics ----------' GEM5STATS_ROI_NUMBER = 8 -logger = logging.getLogger('gem5') +logger = get_logger('gem5') def iter_statistics_dump(stats_file): diff --git a/devlib/utils/misc.py b/devlib/utils/misc.py index 1c49d0d0b..2e5f25e48 100644 --- a/devlib/utils/misc.py +++ b/devlib/utils/misc.py @@ -1,4 +1,4 @@ -# Copyright 2013-2024 ARM Limited +# Copyright 2013-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,8 @@ from operator import itemgetter from weakref import WeakSet from ruamel.yaml import YAML +from ruamel.yaml.error import YAMLError, MarkedYAMLError +from devlib.utils.annotation_helpers import SubprocessCommand import ctypes import logging @@ -39,28 +41,35 @@ import warnings import wrapt - try: from contextlib import ExitStack except AttributeError: - from contextlib2 import ExitStack + from contextlib2 import ExitStack # type: ignore from shlex import quote -from past.builtins import basestring # pylint: disable=redefined-builtin from devlib.exception import HostError, TimeoutError +from typing import (Union, List, Optional, Tuple, Set, + Any, Callable, Dict, TYPE_CHECKING, + Type, cast, Pattern) +from collections.abc import Generator +from typing_extensions import Literal +if TYPE_CHECKING: + from logging import Logger + from tarfile import TarFile, TarInfo + from devlib import Target # ABI --> architectures list -ABI_MAP = { +ABI_MAP: Dict[str, List[str]] = { 'armeabi': ['armeabi', 'armv7', 'armv7l', 'armv7el', 'armv7lh', 'armeabi-v7a'], 'arm64': ['arm64', 'armv8', 'arm64-v8a', 'aarch64'], } # Vendor ID --> CPU part ID --> CPU variant ID --> Core Name # None means variant is not used. 
-CPU_PART_MAP = { +CPU_PART_MAP: Dict[int, Dict[int, Dict[Optional[int], str]]] = { 0x41: { # ARM 0x926: {None: 'ARM926'}, 0x946: {None: 'ARM946'}, @@ -127,16 +136,30 @@ } -def get_cpu_name(implementer, part, variant): +def get_cpu_name(implementer: int, part: int, variant: int) -> Optional[str]: + """ + Retrieve the CPU name based on implementer, part, and variant IDs using the CPU_PART_MAP. + + :param implementer: The vendor identifier. + :param part: The CPU part identifier. + :param variant: The CPU variant identifier. + :returns: The CPU name if found; otherwise, None. + """ part_data = CPU_PART_MAP.get(implementer, {}).get(part, {}) if None in part_data: # variant does not determine core Name for this vendor - name = part_data[None] + name: Optional[str] = part_data[None] else: name = part_data.get(variant) return name -def preexec_function(): +def preexec_function() -> None: + """ + Set the process group ID for the current process so that a subprocess and all its children + can later be killed together. This function is Unix-specific. + + :raises OSError: If setting the process group fails. + """ # Change process group in case we have to kill the subprocess and all of # its children later. # TODO: this is Unix-specific; would be good to find an OS-agnostic way @@ -144,22 +167,53 @@ def preexec_function(): os.setpgrp() -check_output_logger = logging.getLogger('check_output') +def get_logger(name: str) -> logging.Logger: + return logging.getLogger(name) + +check_output_logger: 'Logger' = get_logger('check_output') -def get_subprocess(command, **kwargs): + +def get_subprocess(command: SubprocessCommand, **kwargs) -> subprocess.Popen: + """ + Launch a subprocess to run the specified command, overriding stdout to PIPE. + The process is set to a new process group via a preexec function. + + :param command: The command to execute. + :param kwargs: Additional keyword arguments to pass to subprocess.Popen. + :raises ValueError: If 'stdout' is provided in kwargs. 
+ :returns: A subprocess.Popen object running the command. + """ if 'stdout' in kwargs: raise ValueError('stdout argument not allowed, it will be overridden.') return subprocess.Popen(command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - stdin=subprocess.PIPE, - preexec_fn=preexec_function, - **kwargs) - + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.PIPE, + preexec_fn=preexec_function, + **kwargs) + + +def check_subprocess_output( + process: subprocess.Popen, + timeout: Optional[float] = None, + ignore: Optional[Union[int, List[int], Literal['all']]] = None, + inputtext: Union[str, bytes, None] = None) -> Tuple[str, str]: + """ + Communicate with the given subprocess and return its decoded output and error streams. + This function handles timeouts and can ignore specified return codes. + + :param process: The subprocess.Popen instance to interact with. + :param timeout: The maximum time in seconds to wait for the process to complete. + :param ignore: A return code (or list of codes) to ignore; use "all" to ignore all nonzero codes. + :param inputtext: Optional text or bytes to send to the process's stdin. + :returns: A tuple (output, error) with decoded strings. + :raises ValueError: If the ignore parameter is improperly formatted. + :raises TimeoutError: If the process does not complete before the timeout expires. + :raises subprocess.CalledProcessError: If the process exits with a nonzero code not in ignore. 
+ """ + output: Union[str, bytes] = '' + error: Union[str, bytes] = '' -def check_subprocess_output(process, timeout=None, ignore=None, inputtext=None): - output = None - error = None # pylint: disable=too-many-branches if ignore is None: ignore = [] @@ -170,39 +224,70 @@ def check_subprocess_output(process, timeout=None, ignore=None, inputtext=None): raise ValueError(message.format(ignore)) with process: + timeout_expired: Optional[subprocess.TimeoutExpired] = None try: output, error = process.communicate(inputtext, timeout=timeout) except subprocess.TimeoutExpired as e: timeout_expired = e - else: - timeout_expired = None # Currently errors=replace is needed as 0x8c throws an error - output = output.decode(sys.stdout.encoding or 'utf-8', "replace") if output else '' - error = error.decode(sys.stderr.encoding or 'utf-8', "replace") if error else '' + output = cast(str, output.decode(sys.stdout.encoding or 'utf-8', "replace") if isinstance(output, bytes) else output) + error = cast(str, error.decode(sys.stderr.encoding or 'utf-8', "replace") if isinstance(error, bytes) else error) if timeout_expired: raise TimeoutError(process.args, output='\n'.join([output, error])) - retcode = process.returncode + retcode: int = process.returncode if retcode and ignore != 'all' and retcode not in ignore: raise subprocess.CalledProcessError(retcode, process.args, output, error) return output, error -def check_output(command, timeout=None, ignore=None, inputtext=None, **kwargs): - """This is a version of subprocess.check_output that adds a timeout parameter to kill - the subprocess if it does not return within the specified time.""" +def check_output(command: SubprocessCommand, timeout: Optional[int] = None, + ignore: Optional[Union[int, List[int], Literal['all']]] = None, + inputtext: Union[str, bytes, None] = None, **kwargs) -> Tuple[str, str]: + """ + This is a version of subprocess.check_output that adds a timeout parameter to kill + the subprocess if it does not return within 
the specified time. + + :param command: The command to execute. + :param timeout: Time in seconds to wait for the command to complete. + :param ignore: A return code or list of return codes to ignore, or "all" to ignore all. + :param inputtext: Optional text or bytes to send to the command's stdin. + :param kwargs: Additional keyword arguments for subprocess.Popen. + :returns: A tuple (stdout, stderr) of the command's decoded output. + :raises TimeoutError: If the command does not complete in time. + :raises subprocess.CalledProcessError: If the command fails and its return code is not ignored. + """ process = get_subprocess(command, **kwargs) return check_subprocess_output(process, timeout=timeout, ignore=ignore, inputtext=inputtext) -def walk_modules(path): +class ExtendedHostError(HostError): + """ + Exception class that extends HostError with additional attributes. + + :param message: The error message. + :param module: The name of the module where the error originated. + :param exc_info: Exception information from sys.exc_info(). + :param orig_exc: The original exception that was caught. + """ + def __init__(self, message: str, module: Optional[str] = None, + exc_info: Any = None, orig_exc: Optional[Exception] = None): + super().__init__(message) + self.module = module + self.exc_info = exc_info + self.orig_exc = orig_exc + + +def walk_modules(path: str) -> List[types.ModuleType]: """ Given package name, return a list of all modules (including submodules, etc) in that package. + :param path: The package name to walk (e.g., 'mypackage'). + :returns: A list of module objects. :raises HostError: if an exception is raised while trying to import one of the modules under ``path``. 
The exception will have addtional attributes set: ``module`` will be set to the qualified name @@ -211,39 +296,50 @@ def walk_modules(path): """ - def __try_import(path): + def __try_import(path: str) -> types.ModuleType: try: return __import__(path, {}, {}, ['']) except Exception as e: he = HostError('Could not load {}: {}'.format(path, str(e))) - he.module = path - he.exc_info = sys.exc_info() - he.orig_exc = e + cast(ExtendedHostError, he).module = path + cast(ExtendedHostError, he).exc_info = sys.exc_info() + cast(ExtendedHostError, he).orig_exc = e raise he - root_mod = __try_import(path) - mods = [root_mod] + root_mod: types.ModuleType = __try_import(path) + mods: List[types.ModuleType] = [root_mod] if not hasattr(root_mod, '__path__'): # root is a module not a package -- nothing to walk return mods for _, name, ispkg in pkgutil.iter_modules(root_mod.__path__): - submod_path = '.'.join([path, name]) + submod_path: str = '.'.join([path, name]) if ispkg: mods.extend(walk_modules(submod_path)) else: - submod = __try_import(submod_path) + submod: types.ModuleType = __try_import(submod_path) mods.append(submod) return mods -def redirect_streams(stdout, stderr, command): + +def redirect_streams(stdout: int, stderr: int, + command: SubprocessCommand) -> Tuple[int, int, SubprocessCommand]: """ - Update a command to redirect a given stream to /dev/null if it's - ``subprocess.DEVNULL``. + Adjust a command string to redirect output streams to specific targets. + If a stream is set to subprocess.DEVNULL, it replaces it with a redirect + to /dev/null; for subprocess.STDOUT, it merges stderr into stdout. + + :param stdout: The desired stdout value. + :param stderr: The desired stderr value. + :param command: The original command to run. :return: A tuple (stdout, stderr, command) with stream set to ``subprocess.PIPE`` if the `stream` parameter was set to ``subprocess.DEVNULL``. 
""" - def redirect(stream, redirection): + + def redirect(stream: int, redirection: str) -> Tuple[int, str]: + """ + redirect output and error streams + """ if stream == subprocess.DEVNULL: suffix = '{}/dev/null'.format(redirection) elif stream == subprocess.STDOUT: @@ -259,47 +355,76 @@ def redirect(stream, redirection): stdout, suffix1 = redirect(stdout, '>') stderr, suffix2 = redirect(stderr, '2>') - command = 'sh -c {} {} {}'.format(quote(command), suffix1, suffix2) + command = 'sh -c {} {} {}'.format(quote(cast(str, command)), suffix1, suffix2) return (stdout, stderr, command) -def ensure_directory_exists(dirpath): + +def ensure_directory_exists(dirpath: str) -> str: """A filter for directory paths to ensure they exist.""" if not os.path.isdir(dirpath): os.makedirs(dirpath) return dirpath -def ensure_file_directory_exists(filepath): +def ensure_file_directory_exists(filepath: str) -> str: """ A filter for file paths to ensure the directory of the file exists and the file can be created there. The file itself is *not* going to be created if it doesn't already exist. + :param dirpath: The directory path to check. + :returns: The directory path. + :raises OSError: If the directory cannot be created """ ensure_directory_exists(os.path.dirname(filepath)) return filepath -def merge_dicts(*args, **kwargs): +def merge_dicts(*args, **kwargs) -> Dict: + """ + Merge multiple dictionaries together. + + :param args: Two or more dictionaries to merge. + :param kwargs: Additional keyword arguments to pass to the merging function. + :returns: A new dictionary containing the merged keys and values. + :raises ValueError: If fewer than two dictionaries are provided. 
+ """ if not len(args) >= 2: raise ValueError('Must specify at least two dicts to merge.') - func = partial(_merge_two_dicts, **kwargs) + func: partial[Dict] = partial(_merge_two_dicts, **kwargs) return reduce(func, args) -def _merge_two_dicts(base, other, list_duplicates='all', match_types=False, # pylint: disable=R0912,R0914 - dict_type=dict, should_normalize=True, should_merge_lists=True): - """Merge dicts normalizing their keys.""" +def _merge_two_dicts(base: Dict, other: Dict, list_duplicates: str = 'all', + match_types: bool = False, # pylint: disable=R0912,R0914 + dict_type: Type[Dict] = dict, should_normalize: bool = True, + should_merge_lists: bool = True) -> Dict: + """ + Merge two dictionaries recursively, normalizing their keys. The merging behavior + for lists and duplicate keys can be controlled via parameters. + + :param base: The first dictionary. + :param other: The second dictionary to merge into the first. + :param list_duplicates: Strategy for handling duplicate list entries ("all", "first", or "last"). + :param match_types: If True, enforce that overlapping keys have the same type. + :param dict_type: The dictionary type to use for constructing merged dictionaries. + :param should_normalize: If True, normalize keys/values during merge. + :param should_merge_lists: If True, merge lists; otherwise, override base list. + :returns: A merged dictionary. + :raises ValueError: If there is a type mismatch for a key when match_types is True. + :raises AssertionError: If an unexpected merge key is encountered. + """ merged = dict_type() base_keys = list(base.keys()) other_keys = list(other.keys()) - norm = normalize if should_normalize else lambda x, y: x + # FIXME - annotate the lambda. 
type checker is not able to deduce its type + norm: Callable = normalize if should_normalize else lambda x, y: x # type:ignore - base_only = [] - other_only = [] - both = [] - union = [] + base_only: List = [] + other_only: List = [] + both: List = [] + union: List = [] for k in base_keys: if k in other_keys: both.append(k) @@ -345,50 +470,70 @@ def _merge_two_dicts(base, other, list_duplicates='all', match_types=False, # p return merged -def merge_lists(*args, **kwargs): +def merge_lists(*args, **kwargs) -> List: + """ + Merge multiple lists together. + + :param args: Two or more lists to merge. + :param kwargs: Additional keyword arguments to pass to the merging function. + :returns: A merged list containing the combined items. + :raises ValueError: If fewer than two lists are provided. + """ if not len(args) >= 2: raise ValueError('Must specify at least two lists to merge.') func = partial(_merge_two_lists, **kwargs) return reduce(func, args) -def _merge_two_lists(base, other, duplicates='all', dict_type=dict): # pylint: disable=R0912 +def _merge_two_lists(base: List, other: List, duplicates: str = 'all', + dict_type: Type[Dict] = dict) -> List: # pylint: disable=R0912 """ Merge lists, normalizing their entries. - parameters: + :param base: The base list. + :param other: The list to merge into base. + :param duplicates: Indicates the strategy of handling entries that appear + in both lists. ``all`` will keep occurrences from both + lists; ``first`` will only keep occurrences from + ``base``; ``last`` will only keep occurrences from + ``other``; - :base, other: the two lists to be merged. ``other`` will be merged on - top of base. - :duplicates: Indicates the strategy of handling entries that appear - in both lists. ``all`` will keep occurrences from both - lists; ``first`` will only keep occurrences from - ``base``; ``last`` will only keep occurrences from - ``other``; - - .. note:: duplicate entries that appear in the *same* list + .. 
note:: duplicate entries that appear in the *same* list will never be removed. - + :param dict_type: The dictionary type to use for normalization if needed. + :returns: A merged list with duplicate handling applied. + :raises ValueError: If an unexpected value is provided for duplicates. """ if not isiterable(base): base = [base] if not isiterable(other): other = [other] if duplicates == 'all': - merged_list = [] - for v in normalize(base, dict_type) + normalize(other, dict_type): + merged_list: List = [] + combined: List = [] + normalized_base = normalize(base, dict_type) + normalized_other = normalize(other, dict_type) + if isinstance(normalized_base, (list, tuple)) and isinstance(normalized_other, (list, tuple)): + combined = list(normalized_base) + list(normalized_other) + elif isinstance(normalized_base, dict) and isinstance(normalized_other, dict): + combined = [normalized_base, normalized_other] + elif isinstance(normalized_base, set) and isinstance(normalized_other, set): + combined = list(normalized_base.union(normalized_other)) + else: + combined = list(normalized_base) + list(normalized_other) + for v in combined: if not _check_remove_item(merged_list, v): merged_list.append(v) return merged_list elif duplicates == 'first': base_norm = normalize(base, dict_type) - merged_list = normalize(base, dict_type) + merged_list = cast(List, normalize(base, dict_type)) for v in base_norm: _check_remove_item(merged_list, v) for v in normalize(other, dict_type): if not _check_remove_item(merged_list, v): if v not in base_norm: - merged_list.append(v) # pylint: disable=no-member + cast(List, merged_list).append(v) # pylint: disable=no-member return merged_list elif duplicates == 'last': other_norm = normalize(other, dict_type) @@ -406,11 +551,16 @@ def _merge_two_lists(base, other, duplicates='all', dict_type=dict): # pylint: 'Must be in {"all", "first", "last"}.') -def _check_remove_item(the_list, item): - """Helper function for merge_lists that implements checking 
wether an items - should be removed from the list and doing so if needed. Returns ``True`` if - the item has been removed and ``False`` otherwise.""" - if not isinstance(item, basestring): +def _check_remove_item(the_list: List, item: Any) -> bool: + """ + Check whether an item should be removed from a list based on certain criteria. + If the item is a string starting with '~', its unprefixed version is removed from the list. + + :param the_list: The list in which to check for the item. + :param item: The item to check. + :returns: True if the item was removed; False otherwise. + """ + if not isinstance(item, str): return False if not item.startswith('~'): return False @@ -420,9 +570,16 @@ def _check_remove_item(the_list, item): return True -def normalize(value, dict_type=dict): - """Normalize values. Recursively normalizes dict keys to be lower case, - no surrounding whitespace, underscore-delimited strings.""" +def normalize(value: Union[Dict, List, Tuple, Set], + dict_type: Type[Dict] = dict) -> Union[Dict, List, Tuple, Set]: + """ + Recursively normalize values by converting dictionary keys to lower-case, + stripping whitespace, and replacing spaces with underscores. + + :param value: A dict, list, tuple, or set to normalize. + :param dict_type: The dictionary type to use for normalized dictionaries. + :returns: A normalized version of the input value. + """ if isinstance(value, dict): normalized = dict_type() for k, v in value.items(): @@ -437,12 +594,25 @@ def normalize(value, dict_type=dict): return value -def convert_new_lines(text): - """ Convert new lines to a common format. """ +def convert_new_lines(text: str) -> str: + """ + Convert different newline conventions to a single '\n' format. + + :param text: The input text. + :returns: The text with unified newline characters. 
+ """ return text.replace('\r\n', '\n').replace('\r', '\n') -def sanitize_cmd_template(cmd): - msg = ( + +def sanitize_cmd_template(cmd: str) -> str: + """ + Replace quoted placeholders with unquoted ones in a command template, + warning the user if quoted placeholders are detected. + + :param cmd: The command template string. + :returns: The sanitized command template. + """ + msg: str = ( '''Quoted placeholder should not be used, as it will result in quoting the text twice. {} should be used instead of '{}' or "{}" in the template: ''' ) for unwanted in ('"{}"', "'{}'"): @@ -452,51 +622,69 @@ def sanitize_cmd_template(cmd): return cmd -def escape_quotes(text): + +def escape_quotes(text: str) -> str: """ - Escape quotes, and escaped quotes, in the specified text. + Escape quotes and escaped quotes in the given text. + + .. note:: It is recommended to use shlex.quote when possible. - .. note:: :func:`shlex.quote` should be favored where possible. + :param text: The text to escape. + :returns: The text with quotes escaped. """ return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\'', '\\\'').replace('\"', '\\\"') -def escape_single_quotes(text): +def escape_single_quotes(text: str) -> str: """ - Escape single quotes, and escaped single quotes, in the specified text. + Escape single quotes in the provided text. - .. note:: :func:`shlex.quote` should be favored where possible. + .. note:: Prefer using shlex.quote when possible. + + :param text: The text to process. + :returns: The text with single quotes escaped. """ return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\'', '\'\\\'\'') -def escape_double_quotes(text): +def escape_double_quotes(text: str) -> str: """ - Escape double quotes, and escaped double quotes, in the specified text. + Escape double quotes in the given text. + + .. note:: Prefer using shlex.quote when possible. - .. note:: :func:`shlex.quote` should be favored where possible. + :param text: The input text. 
+ :returns: The text with double quotes escaped. """ return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\"', '\\\"') -def escape_spaces(text): +def escape_spaces(text: str) -> str: """ - Escape spaces in the specified text + Escape spaces in the provided text. - .. note:: :func:`shlex.quote` should be favored where possible. + .. note:: Prefer using shlex.quote when possible. + + :param text: The text to process. + :returns: The text with spaces escaped. """ return text.replace(' ', '\\ ') -def getch(count=1): - """Read ``count`` characters from standard input.""" +def getch(count: int = 1) -> str: + """ + Read a specified number of characters from standard input. + + :param count: The number of characters to read. + :returns: A string of characters read from stdin. + """ if os.name == 'nt': import msvcrt # pylint: disable=F0401 - return ''.join([msvcrt.getch() for _ in range(count)]) + return ''.join([msvcrt.getch() for _ in range(count)]) # type:ignore else: # assume Unix import tty # NOQA import termios # NOQA - fd = sys.stdin.fileno() + fd: int = sys.stdin.fileno() old_settings = termios.tcgetattr(fd) try: tty.setraw(sys.stdin.fileno()) @@ -506,45 +694,70 @@ def getch(count=1): return ch -def isiterable(obj): - """Returns ``True`` if the specified object is iterable and - *is not a string type*, ``False`` otherwise.""" - return hasattr(obj, '__iter__') and not isinstance(obj, basestring) +def isiterable(obj: Any) -> bool: + """ + Determine if the provided object is iterable, excluding strings. + + :param obj: The object to test. + :returns: True if the object is iterable and is not a string; otherwise, False. + """ + return hasattr(obj, '__iter__') and not isinstance(obj, str) + +def as_relative(path: str) -> str: + """ + Convert an absolute path to a relative path by removing leading separators. 
-def as_relative(path): - """Convert path to relative by stripping away the leading '/' on UNIX or - the equivant on other platforms.""" + :param path: The absolute path. + :returns: A relative path. + """ path = os.path.splitdrive(path)[1] return path.lstrip(os.sep) -def commonprefix(file_list, sep=os.sep): +def commonprefix(file_list: List[str], sep: str = os.sep) -> str: """ - Find the lowest common base folder of a passed list of files. + Determine the lowest common base folder among a list of file paths. + + :param file_list: A list of file paths. + :param sep: The path separator to use. + :returns: The common prefix path. """ - common_path = os.path.commonprefix(file_list) - cp_split = common_path.split(sep) - other_split = file_list[0].split(sep) - last = len(cp_split) - 1 + common_path: str = os.path.commonprefix(file_list) + cp_split: List[str] = common_path.split(sep) + other_split: List[str] = file_list[0].split(sep) + last: int = len(cp_split) - 1 if cp_split[last] != other_split[last]: cp_split = cp_split[:-1] return sep.join(cp_split) -def get_cpu_mask(cores): - """Return a string with the hex for the cpu mask for the specified core numbers.""" +def get_cpu_mask(cores: List[int]) -> str: + """ + Compute a hexadecimal CPU mask for the specified core indices. + + :param cores: A list of core numbers. + :returns: A hexadecimal string representing the CPU mask. + """ mask = 0 for i in cores: mask |= 1 << i return '0x{0:x}'.format(mask) -def which(name): - """Platform-independent version of UNIX which utility.""" +def which(name: str) -> Optional[str]: + """ + Find the full path to an executable by searching the system PATH. + Provides a platform-independent implementation of the UNIX 'which' utility. + + :param name: The name of the executable to find. + :returns: The full path to the executable if found, otherwise None. 
+ """ if os.name == 'nt': - paths = os.getenv('PATH').split(os.pathsep) - exts = os.getenv('PATHEXT').split(os.pathsep) + path_env = os.getenv('PATH') + pathext_env = os.getenv('PATHEXT') + paths: List[str] = path_env.split(os.pathsep) if path_env else [] + exts: List[str] = pathext_env.split(os.pathsep) if pathext_env else [] for path in paths: testpath = os.path.join(path, name) if os.path.isfile(testpath): @@ -562,13 +775,20 @@ def which(name): # This matches most ANSI escape sequences, not just colors -_bash_color_regex = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]') +_bash_color_regex: Pattern[str] = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]') + -def strip_bash_colors(text): +def strip_bash_colors(text: str) -> str: + """ + Remove ANSI escape sequences (commonly used for terminal colors) from the given text. + + :param text: The input string potentially containing ANSI escape sequences. + :returns: The input text with all ANSI escape sequences removed. + """ return _bash_color_regex.sub('', text) -def get_random_string(length): +def get_random_string(length: int) -> str: """Returns a random ASCII string of the specified length).""" return ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(length)) @@ -581,7 +801,7 @@ def message(self): return self.args[0] return str(self) - def __init__(self, message, filepath, lineno): + def __init__(self, message: str, filepath: str, lineno: Optional[int]): super(LoadSyntaxError, self).__init__(message) self.filepath = filepath self.lineno = lineno @@ -591,36 +811,34 @@ def __str__(self): return message.format(self.filepath, self.lineno, self.message) -def load_struct_from_yaml(filepath): +def load_struct_from_yaml(filepath: str) -> Dict: """ Parses a config structure from a YAML file. The structure should be composed of basic Python types. :param filepath: Input file which contains YAML data. - :type filepath: str :raises LoadSyntaxError: if there is a syntax error in YAML data. 
:return: A dictionary which contains parsed YAML data - :rtype: Dict """ try: yaml = YAML(typ='safe', pure=True) with open(filepath, 'r', encoding='utf-8') as file_handler: return yaml.load(file_handler) - except yaml.YAMLError as ex: - message = ex.message if hasattr(ex, 'message') else '' - lineno = ex.problem_mark.line if hasattr(ex, 'problem_mark') else None + except YAMLError as ex: + message = str(ex) + lineno = cast(MarkedYAMLError, ex).problem_mark.line if hasattr(ex, 'problem_mark') else None raise LoadSyntaxError(message, filepath=filepath, lineno=lineno) from ex -RAND_MOD_NAME_LEN = 30 -BAD_CHARS = string.punctuation + string.whitespace -TRANS_TABLE = str.maketrans(BAD_CHARS, '_' * len(BAD_CHARS)) +RAND_MOD_NAME_LEN: int = 30 +BAD_CHARS: str = string.punctuation + string.whitespace +TRANS_TABLE: Dict[int, int] = str.maketrans(BAD_CHARS, '_' * len(BAD_CHARS)) -def to_identifier(text): +def to_identifier(text: str) -> str: """Converts text to a valid Python identifier by replacing all whitespace and punctuation and adding a prefix if starting with a digit""" if text[:1].isdigit(): @@ -628,7 +846,7 @@ def to_identifier(text): return re.sub('_+', '_', str(text).translate(TRANS_TABLE)) -def unique(alist): +def unique(alist: List) -> List: """ Returns a list containing only unique elements from the input list (but preserves order, unlike sets). @@ -641,9 +859,9 @@ def unique(alist): return result -def ranges_to_list(ranges_string): +def ranges_to_list(ranges_string: str) -> List[int]: """Converts a sysfs-style ranges string, e.g. ``"0,2-4"``, into a list ,e.g ``[0,2,3,4]``""" - values = [] + values: List[int] = [] for rg in ranges_string.split(','): if '-' in rg: first, last = list(map(int, rg.split('-'))) @@ -653,13 +871,13 @@ def ranges_to_list(ranges_string): return values -def list_to_ranges(values): +def list_to_ranges(values: List) -> str: """Converts a list, e.g ``[0,2,3,4]``, into a sysfs-style ranges string, e.g. 
``"0,2-4"``""" values = sorted(values) range_groups = [] for _, g in groupby(enumerate(values), lambda i_x: i_x[0] - i_x[1]): range_groups.append(list(map(itemgetter(1), g))) - range_strings = [] + range_strings: List[str] = [] for group in range_groups: if len(group) == 1: range_strings.append(str(group[0])) @@ -668,7 +886,7 @@ def list_to_ranges(values): return ','.join(range_strings) -def list_to_mask(values, base=0x0): +def list_to_mask(values: List[int], base: int = 0x0) -> int: """Converts the specified list of integer values into a bit mask for those values. Optinally, the list can be applied to an existing mask.""" @@ -677,7 +895,7 @@ def list_to_mask(values, base=0x0): return base -def mask_to_list(mask): +def mask_to_list(mask: int) -> List[int]: """Converts the specfied integer bitmask into a list of indexes of bits that are set in the mask.""" size = len(bin(mask)) - 2 # because of "0b" @@ -685,27 +903,32 @@ def mask_to_list(mask): if mask & (1 << size - i - 1)] -__memo_cache = {} +__memo_cache: Dict[str, Any] = {} + +def reset_memo_cache() -> None: + """ + Clear the global memoization cache used for caching function results. -def reset_memo_cache(): + :returns: None + """ __memo_cache.clear() -def __get_memo_id(obj): +def __get_memo_id(obj: object) -> str: """ An object's id() may be re-used after an object is freed, so it's not sufficiently unique to identify params for the memo cache (two different params may end up with the same id). this attempts to generate a more unique ID string. """ - obj_id = id(obj) + obj_id: int = id(obj) try: return '{}/{}'.format(obj_id, hash(obj)) except TypeError: # obj is not hashable obj_pyobj = ctypes.cast(obj_id, ctypes.py_object) # TODO: Note: there is still a possibility of a clash here. If Two - # different objects get assigned the same ID, an are large and are + # different objects get assigned the same ID, and are large and are # identical in the first thirty two bytes. 
This shouldn't be much of an # issue in the current application of memoizing Target calls, as it's very # unlikely that a target will get passed large params; but may cause @@ -715,24 +938,33 @@ def __get_memo_id(obj): # undesirable impact on performance. num_bytes = min(ctypes.sizeof(obj_pyobj), 32) obj_bytes = ctypes.string_at(ctypes.addressof(obj_pyobj), num_bytes) - return '{}/{}'.format(obj_id, obj_bytes) + return '{}/{}'.format(obj_id, cast(str, obj_bytes)) -@wrapt.decorator -def memoized(wrapped, instance, args, kwargs): # pylint: disable=unused-argument +def memoized_decor(wrapped: Callable[..., Any], instance: Optional[Any], + args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: # pylint: disable=unused-argument """ - A decorator for memoizing functions and methods. + Decorator helper function for memoizing the results of a function call. + The result is cached based on a key derived from the function's arguments. + Note that this method does not account for changes to mutable arguments. .. warning:: this may not detect changes to mutable types. As long as the memoized function was used with an object as an argument before, the cached result will be returned, even if the structure of the object (e.g. a list) has changed in the mean time. + :param wrapped: The function to be memoized. + :param instance: The instance on which the function is called (if it is a method), or None. + :param args: Tuple of positional arguments passed to the function. + :param kwargs: Dictionary of keyword arguments passed to the function. + :returns: The cached result if available; otherwise, the result from calling the function. + :raises Exception: Any exception raised during the execution of the wrapped function is propagated. 
+ """ - func_id = repr(wrapped) + func_id: str = repr(wrapped) - def memoize_wrapper(*args, **kwargs): - id_string = func_id + ','.join([__get_memo_id(a) for a in args]) + def memoize_wrapper(*args, **kwargs) -> Dict[str, Any]: + id_string: str = func_id + ','.join([__get_memo_id(a) for a in args]) id_string += ','.join('{}={}'.format(k, __get_memo_id(v)) for k, v in kwargs.items()) if id_string not in __memo_cache: @@ -741,8 +973,13 @@ def memoize_wrapper(*args, **kwargs): return memoize_wrapper(*args, **kwargs) + +# create memoized decorator from memoized_decor function +memoized = wrapt.decorator(memoized_decor) + + @contextmanager -def batch_contextmanager(f, kwargs_list): +def batch_contextmanager(f: Callable, kwargs_list: List[Dict[str, Any]]) -> Generator: """ Return a context manager that will call the ``f`` callable with the keyword arguments dict in the given list, in one go. @@ -750,7 +987,6 @@ def batch_contextmanager(f, kwargs_list): :param f: Callable expected to return a context manager. :param kwargs_list: list of kwargs dictionaries to be used to call ``f``. - :type kwargs_list: list(dict) """ with ExitStack() as stack: for kwargs in kwargs_list: @@ -768,9 +1004,9 @@ class nullcontext: :param enter_result: Object that will be bound to the target of the with statement, or `None` if nothing is specified. - :type enter_result: object """ - def __init__(self, enter_result=None): + + def __init__(self, enter_result: Any = None): self.enter_result = enter_result def __enter__(self): @@ -797,21 +1033,43 @@ class tls_property: to that object, like :meth:`_BoundTLSProperty.get_all_values`. Values can be set and deleted as well, which will be a thread-local set. + + :param factory: A callable used to generate the property value. """ @property - def name(self): + def name(self) -> str: + """ + Retrieve the name of the factory function used for this property. + + :returns: The name of the factory function. 
+ """ return self.factory.__name__ - def __init__(self, factory): + def __init__(self, factory: Callable): self.factory = factory # Lock accesses to shared WeakKeyDictionary and WeakSet self.lock = threading.RLock() - def __get__(self, instance, owner=None): + def __get__(self, instance: 'Target', owner: Optional[Type['Target']] = None) -> '_BoundTLSProperty': + """ + Retrieve the thread-local property proxy for the given instance. + + :param instance: The target instance. + :param owner: The class owning the property (optional). + :returns: A bound TLS property proxy. + """ return _BoundTLSProperty(self, instance, owner) - def _get_value(self, instance, owner): + def _get_value(self, instance: 'Target', owner: Optional[Type['Target']]) -> Any: + """ + Retrieve or compute the thread-local value for the given instance. If the value + does not exist, it is created using the factory callable. + + :param instance: The target instance. + :param owner: The class owning the property (optional). + :returns: The thread-local value. + """ tls, values = self._get_tls(instance) try: return tls.value @@ -826,20 +1084,38 @@ def _get_value(self, instance, owner): values.add(obj) return obj - def _get_all_values(self, instance, owner): + def _get_all_values(self, instance: 'Target', owner: Optional[Type['Target']]) -> Set: + """ + Retrieve all thread-local values currently cached for this property in the given instance. + + :param instance: The target instance. + :param owner: The class owning the property (optional). + :returns: A set containing all cached values. + """ with self.lock: # Grab a reference to all the objects at the time of the call by # using a regular set tls, values = self._get_tls(instance=instance) return set(values) - def __set__(self, instance, value): + def __set__(self, instance: 'Target', value): + """ + Set the thread-local value for this property on the given instance. + + :param instance: The target instance. + :param value: The value to set. 
+ """ tls, values = self._get_tls(instance) tls.value = value with self.lock: values.add(value) - def __delete__(self, instance): + def __delete__(self, instance: 'Target'): + """ + Delete the thread-local value for this property from the given instance. + + :param instance: The target instance. + """ tls, values = self._get_tls(instance) with self.lock: try: @@ -850,7 +1126,14 @@ def __delete__(self, instance): values.discard(value) del tls.value - def _get_tls(self, instance): + def _get_tls(self, instance: 'Target') -> Any: + """ + Retrieve the thread-local storage tuple for this property from the instance. + If not present, a new tuple is created and stored. + + :param instance: The target instance. + :returns: A tuple (tls, values) where tls is a thread-local object and values is a WeakSet. + """ dct = instance.__dict__ name = self.name try: @@ -868,40 +1151,56 @@ def _get_tls(self, instance): return tls @property - def basic_property(self): + def basic_property(self) -> property: """ Return a basic property that can be used to access the TLS value without having to call it first. The drawback is that it's not possible to do anything over than getting/setting/deleting. + + :returns: A property object for direct access. """ + def getter(instance, owner=None): prop = self.__get__(instance, owner) return prop() return property(getter, self.__set__, self.__delete__) + class _BoundTLSProperty: """ Simple proxy object to allow either calling it to get the TLS value, or get some other informations by calling methods. + + :param tls_property: The tls_property descriptor. + :param instance: The target instance to which the property is bound. + :param owner: The owning class (optional). 
""" - def __init__(self, tls_property, instance, owner): + + def __init__(self, tls_property: tls_property, instance: 'Target', owner: Optional[Type['Target']]): self.tls_property = tls_property self.instance = instance self.owner = owner def __call__(self): + """ + Retrieve the thread-local value by calling the underlying tls_property. + + :returns: The thread-local value. + """ return self.tls_property._get_value( instance=self.instance, owner=self.owner, ) - def get_all_values(self): + def get_all_values(self) -> Set[Any]: """ Returns all the thread-local values currently in use in the process for that property for that instance. + + :returns: A set of all thread-local values. """ return self.tls_property._get_all_values( instance=self.instance, @@ -920,9 +1219,20 @@ class InitCheckpointMeta(type): ``is_in_use`` is set to ``True`` when an instance method is being called. This allows to detect reentrance. """ - def __new__(metacls, name, bases, dct, **kwargs): + + def __new__(metacls, name: str, bases: Tuple, dct: Dict, **kwargs: Dict) -> Type: + """ + Create a new class with the augmented __init__ and methods for tracking initialization + and usage. + + :param name: The name of the new class. + :param bases: Base classes for the new class. + :param dct: Dictionary of attributes for the new class. + :param kwargs: Additional keyword arguments. + :returns: The newly created class. + """ cls = super().__new__(metacls, name, bases, dct, **kwargs) - init_f = cls.__init__ + init_f = cls.__init__ # type:ignore @wraps(init_f) def init_wrapper(self, *args, **kwargs): @@ -949,7 +1259,7 @@ def init_wrapper(self, *args, **kwargs): return x - cls.__init__ = init_wrapper + cls.__init__ = init_wrapper # type:ignore # Set the is_in_use attribute to allow external code to detect if the # methods are about to be re-entered. 
@@ -977,8 +1287,8 @@ def wrapper(self, *args, **kwargs): # Only wrap the methods (exposed as functions), not things like # classmethod or staticmethod if ( - name not in ('__init__', '__new__') and - isinstance(attr, types.FunctionType) + name not in ('__init__', '__new__') and + isinstance(attr, types.FunctionType) ): setattr(cls, name, make_wrapper(attr)) elif isinstance(attr, property): @@ -1000,7 +1310,7 @@ class InitCheckpoint(metaclass=InitCheckpointMeta): pass -def groupby_value(dct): +def groupby_value(dct: Dict[Any, Any]) -> Dict[Tuple[Any, ...], Any]: """ Process the input dict such that all keys sharing the same values are grouped in a tuple, used as key in the returned dict. @@ -1013,7 +1323,8 @@ def groupby_value(dct): } -def safe_extract(tar, path=".", members=None, *, numeric_owner=False): +def safe_extract(tar: 'TarFile', path: str = ".", members: Optional[List['TarInfo']] = None, + *, numeric_owner: bool = False) -> None: """ A wrapper around TarFile.extract all to mitigate CVE-2007-4995 (see https://www.trellix.com/en-us/about/newsroom/stories/research/tarfile-exploiting-the-world.html) @@ -1026,8 +1337,8 @@ def safe_extract(tar, path=".", members=None, *, numeric_owner=False): tar.extractall(path, members, numeric_owner=numeric_owner) -def _is_within_directory(directory, target): +def _is_within_directory(directory: str, target: str) -> bool: abs_directory = os.path.abspath(directory) abs_target = os.path.abspath(target) diff --git a/devlib/utils/parse_aep.py b/devlib/utils/parse_aep.py old mode 100755 new mode 100644 index 111aa0240..db1496edd --- a/devlib/utils/parse_aep.py +++ b/devlib/utils/parse_aep.py @@ -33,7 +33,9 @@ import signal import sys -logger = logging.getLogger('aep-parser') +from devlib.utils.misc import get_logger + +logger = get_logger('aep-parser') # pylint: disable=attribute-defined-outside-init class AepParser(object): diff --git a/devlib/utils/rendering.py b/devlib/utils/rendering.py index 52d4f00dc..428f13292 100644 --- 
a/devlib/utils/rendering.py +++ b/devlib/utils/rendering.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ # limitations under the License. # -import logging import os import shutil import sys @@ -24,11 +23,14 @@ from shlex import quote # pylint: disable=redefined-builtin -from devlib.exception import WorkerThreadError, TargetNotRespondingError, TimeoutError +from devlib.exception import WorkerThreadError, TargetNotRespondingError, TimeoutError from devlib.utils.csvutil import csvwriter +from devlib.utils.misc import get_logger +from typing import List, Optional, TYPE_CHECKING, cast +if TYPE_CHECKING: + from devlib.target import Target - -logger = logging.getLogger('rendering') +logger = get_logger('rendering') SurfaceFlingerFrame = namedtuple('SurfaceFlingerFrame', 'desired_present_time actual_present_time frame_ready_time') @@ -38,12 +40,12 @@ class FrameCollector(threading.Thread): - def __init__(self, target, period): + def __init__(self, target: 'Target', period: int): super(FrameCollector, self).__init__() self.target = target self.period = period self.stop_signal = threading.Event() - self.frames = [] + self.frames: List = [] self.temp_file = None self.refresh_period = None @@ -51,7 +53,7 @@ def __init__(self, target, period): self.unresponsive_count = 0 self.last_ready_time = 0 self.exc = None - self.header = None + self.header: Optional[List[str]] = None def run(self): logger.debug('Frame data collection started.') @@ -95,17 +97,18 @@ def process_frames(self, outfile=None): os.unlink(self.temp_file) self.temp_file = None - def write_frames(self, outfile, columns=None): + def write_frames(self, outfile, columns: Optional[List[str]] = None): if columns is None: header = self.header frames = self.frames else: - indexes = [] + indexes: List = [] for c in columns: - if c not in 
self.header: - msg = 'Invalid column "{}"; must be in {}' - raise ValueError(msg.format(c, self.header)) - indexes.append(self.header.index(c)) + if self.header: + if c not in self.header: + msg = 'Invalid column "{}"; must be in {}' + raise ValueError(msg.format(c, self.header)) + indexes.append(self.header.index(c)) frames = [[f[i] for i in indexes] for f in self.frames] header = columns with csvwriter(outfile) as writer: @@ -128,7 +131,7 @@ class SurfaceFlingerFrameCollector(FrameCollector): def __init__(self, target, period, view, header=None): super(SurfaceFlingerFrameCollector, self).__init__(target, period) self.view = view - self.header = header or SurfaceFlingerFrame._fields + self.header = cast(List[str], header or SurfaceFlingerFrame._fields) def collect_frames(self, wfh): activities = [a for a in self.list() if a.startswith(self.view)] @@ -180,7 +183,7 @@ def _process_trace_parts(self, parts): if len(parts) == 3: frame = SurfaceFlingerFrame(*parts) if not frame.frame_ready_time: - return # "null" frame + return # "null" frame if frame.frame_ready_time <= self.last_ready_time: return # duplicate frame if (frame.frame_ready_time - frame.desired_present_time) > self.drop_threshold: @@ -196,8 +199,8 @@ def _process_trace_parts(self, parts): logger.warning(msg) -def read_gfxinfo_columns(target): - output = target.execute('dumpsys gfxinfo --list framestats') +def read_gfxinfo_columns(target: 'Target') -> List[str]: + output: str = target.execute('dumpsys gfxinfo --list framestats') lines = iter(output.split('\n')) for line in lines: if line.startswith('---PROFILEDATA---'): @@ -222,7 +225,7 @@ def collect_frames(self, wfh): def clear(self): pass - def _init_header(self, header): + def _init_header(self, header: Optional[List[str]]): if header is not None: self.header = header else: diff --git a/devlib/utils/serial_port.py b/devlib/utils/serial_port.py index c4915a959..10bd59592 100644 --- a/devlib/utils/serial_port.py +++ b/devlib/utils/serial_port.py @@ -1,4 
+1,4 @@ -# Copyright 2013-2024 ARM Limited +# Copyright 2013-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ from pexpect import fdpexpect # pexpect < 4.0.0 does not have fdpexpect module except ImportError: - import fdpexpect + import fdpexpect # type:ignore # Adding pexpect exceptions into this module's namespace @@ -32,6 +32,9 @@ from devlib.exception import HostError +from typing import Optional, TextIO, Union, Tuple +from collections.abc import Generator + class SerialLogger(Logger): @@ -41,17 +44,22 @@ def flush(self): pass -def pulse_dtr(conn, state=True, duration=0.1): +def pulse_dtr(conn: serial.Serial, state: bool = True, duration: float = 0.1) -> None: """Set the DTR line of the specified serial connection to the specified state for the specified duration (note: the initial state of the line is *not* checked.""" - conn.setDTR(state) + conn.dtr = state time.sleep(duration) - conn.setDTR(not state) + conn.dtr = not state # pylint: disable=keyword-arg-before-vararg -def get_connection(timeout, init_dtr=None, logcls=SerialLogger, - logfile=None, *args, **kwargs): +def get_connection(timeout: int, init_dtr: Optional[bool] = None, + logcls=SerialLogger, + logfile: Optional[TextIO] = None, *args, **kwargs) -> Tuple[fdpexpect.fdspawn, + serial.Serial]: + """ + get the serial connection + """ if init_dtr is not None: kwargs['dsrdtr'] = True try: @@ -59,10 +67,10 @@ def get_connection(timeout, init_dtr=None, logcls=SerialLogger, except serial.SerialException as e: raise HostError(str(e)) if init_dtr is not None: - conn.setDTR(init_dtr) + conn.dtr = init_dtr conn.nonblocking() - conn.flushOutput() - target = fdpexpect.fdspawn(conn.fileno(), timeout=timeout, logfile=logfile) + conn.reset_output_buffer() + target: fdpexpect.fdspawn = fdpexpect.fdspawn(conn.fileno(), timeout=timeout, logfile=logfile) target.logfile_read = logcls('read') 
target.logfile_send = logcls('send') @@ -73,15 +81,16 @@ def get_connection(timeout, init_dtr=None, logcls=SerialLogger, # corruption. The delay prevents that. tsln = target.sendline - def sendline(x): - tsln(x) + def sendline(s: Union[str, bytes]) -> int: + ret: int = tsln(s) time.sleep(0.1) + return ret target.sendline = sendline return target, conn -def write_characters(conn, line, delay=0.05): +def write_characters(conn: fdpexpect.fdspawn, line: str, delay: float = 0.05) -> None: """Write a single line out to serial charcter-by-character. This will ensure that nothing will be dropped for longer lines.""" line = line.rstrip('\r\n') @@ -93,8 +102,10 @@ def write_characters(conn, line, delay=0.05): # pylint: disable=keyword-arg-before-vararg @contextmanager -def open_serial_connection(timeout, get_conn=False, init_dtr=None, - logcls=SerialLogger, *args, **kwargs): +def open_serial_connection(timeout: int, get_conn: bool = False, + init_dtr: Optional[bool] = None, + logcls=SerialLogger, *args, **kwargs) -> Generator[Union[Tuple[fdpexpect.fdspawn, serial.Serial], + fdpexpect.fdspawn], None, None]: """ Opens a serial connection to a device. 
@@ -112,11 +123,11 @@ def open_serial_connection(timeout, get_conn=False, init_dtr=None, See: http://pexpect.sourceforge.net/pexpect.html """ - target, conn = get_connection(timeout, init_dtr=init_dtr, - logcls=logcls, *args, **kwargs) + target, conn = get_connection(timeout, init_dtr, + logcls, *args, **kwargs) if get_conn: - target_and_conn = (target, conn) + target_and_conn: Union[Tuple[fdpexpect.fdspawn, serial.Serial], fdpexpect.fdspawn] = (target, conn) else: target_and_conn = target diff --git a/devlib/utils/ssh.py b/devlib/utils/ssh.py index e64f67bd7..d26fd56d7 100644 --- a/devlib/utils/ssh.py +++ b/devlib/utils/ssh.py @@ -1,4 +1,4 @@ -# Copyright 2014-2024 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ # limitations under the License. # - import os import stat import logging @@ -31,21 +30,18 @@ import functools import shutil from shlex import quote - -from paramiko.client import SSHClient, AutoAddPolicy, RejectPolicy -import paramiko.ssh_exception -from scp import SCPClient -# By default paramiko is very verbose, including at the INFO level -logging.getLogger("paramiko").setLevel(logging.WARNING) - # pylint: disable=import-error,wrong-import-position,ungrouped-imports,wrong-import-order -import pexpect +import pexpect # type: ignore try: - from pexpect import pxssh + from pexpect import pxssh # type: ignore # pexpect < 4.0.0 does not have a pxssh module except ImportError: - import pxssh + import pxssh # type: ignore +from paramiko.client import SSHClient, AutoAddPolicy, RejectPolicy, MissingHostKeyPolicy +import paramiko.ssh_exception +from scp import SCPClient + from pexpect import EOF, TIMEOUT, spawn @@ -56,48 +52,107 @@ TargetTransientCalledProcessError, TargetStableCalledProcessError) from devlib.utils.misc import (which, strip_bash_colors, check_output, - sanitize_cmd_template, memoized, 
redirect_streams) + sanitize_cmd_template, memoized, redirect_streams, + get_logger) from devlib.utils.types import boolean -from devlib.connection import ConnectionBase, ParamikoBackgroundCommand, SSHTransferHandle +from devlib.connection import ConnectionBase, ParamikoBackgroundCommand, SSHTransferHandle, TransferManager +from typing import (Optional, TYPE_CHECKING, Tuple, cast, + Callable, Union, List, Sized, Dict, + Pattern, Type) +from collections.abc import Generator +from typing_extensions import Literal +from io import BufferedReader, BufferedWriter +if TYPE_CHECKING: + from devlib.utils.annotation_helpers import SubprocessCommand + from devlib.platform import Platform + from paramiko.transport import Transport + from paramiko.channel import Channel, ChannelStderrFile, ChannelFile, ChannelStdinFile + from paramiko.sftp_client import SFTPClient + from logging import Logger + from subprocess import Popen +# By default paramiko is very verbose, including at the INFO level +get_logger("paramiko").setLevel(logging.WARNING) # Empty prompt with -p '' to avoid adding a leading space to the output. DEFAULT_SSH_SUDO_COMMAND = "sudo -k -p '' -S -- sh -c {}" +""" +Default command template for acquiring sudo privileges over SSH. +""" + +OutStreamType = Tuple[Union[Optional['BufferedReader'], int], Union[Optional['BufferedWriter'], int, bytes]] +""" +Represents a pair of read-end and write-end streams used for background command output. +""" + +ChannelFiles = Tuple['ChannelStdinFile', 'ChannelFile', 'ChannelStderrFile'] +""" +Represents a triple of Paramiko channel file objects for stdin, stdout, stderr. +""" class _SSHEnv: + """ + Provides resolved paths to SSH-related utilities. + + The main usage includes: + - ``ssh`` for connecting to remote hosts, + - ``scp`` for file transfers, + - ``sshpass`` if password authentication is needed. + + The paths are discovered on the host system using :func:`which`. 
+ """ @functools.lru_cache(maxsize=None) - def get_path(self, tool): + def get_path(self, tool: str) -> str: + """ + Return the full path to the specified ``tool`` (one of ``ssh``, ``scp``, or ``sshpass``). + + :param tool: Name of the executable to look for. + :returns: The full path to the requested tool. + :raises HostError: If the tool cannot be found in PATH. + """ if tool in {'ssh', 'scp', 'sshpass'}: - path = which(tool) + path: Optional[str] = which(tool) if path: return path else: raise HostError(f'OpenSSH must be installed on the host: could not find {tool} command') else: raise AttributeError(f"Tool '{tool}' is not supported") + + _SSH_ENV = _SSHEnv() -logger = logging.getLogger('ssh') -gem5_logger = logging.getLogger('gem5-connection') +logger: 'Logger' = get_logger('ssh') +gem5_logger: 'Logger' = get_logger('gem5-connection') @contextlib.contextmanager -def _handle_paramiko_exceptions(command=None): +def _handle_paramiko_exceptions(command: Optional['SubprocessCommand'] = None) -> Generator: + """ + A context manager that catches exceptions from Paramiko calls, raising devlib-friendly + exceptions where appropriate. + + :param command: Optional command string for context in exception messages. + :raises TargetNotRespondingError: If connection issues are detected. + :raises TargetStableError: If there is an SSH logic or host key error. + :raises TargetTransientError: If an SSH logic error suggests a transient condition. + :raises TimeoutError: If a socket timeout occurs. 
+ """ try: yield except paramiko.ssh_exception.NoValidConnectionsError as e: raise TargetNotRespondingError('Connection lost: {}'.format(e)) - except paramiko.ssh_exception.AuthenticationException as e: - raise TargetStableError('Could not authenticate: {}'.format(e)) except paramiko.ssh_exception.BadAuthenticationType as e: raise TargetStableError('Bad authentication type: {}'.format(e)) + except paramiko.ssh_exception.PasswordRequiredException as e: + raise TargetStableError('Please unlock the private key file: {}'.format(e)) + except paramiko.ssh_exception.AuthenticationException as e: + raise TargetStableError('Could not authenticate: {}'.format(e)) except paramiko.ssh_exception.BadHostKeyException as e: raise TargetStableError('Bad host key: {}'.format(e)) except paramiko.ssh_exception.ChannelException as e: raise TargetStableError('Could not open an SSH channel: {}'.format(e)) - except paramiko.ssh_exception.PasswordRequiredException as e: - raise TargetStableError('Please unlock the private key file: {}'.format(e)) except paramiko.ssh_exception.ProxyCommandFailure as e: raise TargetStableError('Proxy command failure: {}'.format(e)) except paramiko.ssh_exception.SSHException as e: @@ -106,7 +161,23 @@ def _handle_paramiko_exceptions(command=None): raise TimeoutError(command, output=None) -def _read_paramiko_streams(stdout, stderr, select_timeout, callback, init, chunk_size=int(1e42)): +def _read_paramiko_streams(stdout: 'ChannelFile', stderr: 'ChannelStderrFile', + select_timeout: Optional[float], callback: Callable, + init: List[bytes], chunk_size=int(1e42)) -> Tuple[Optional[List[bytes]], int]: + """ + Read data from Paramiko's stdout/stderr streams until the channel closes. + Applies an optional callback to each chunk read for each stream. + + :param stdout: Paramiko file-like object for stdout. + :param stderr: Paramiko file-like object for stderr. + :param select_timeout: Maximum time (seconds) to block when reading from the channel. 
+ :param callback: A function receiving (callback_state, 'stdout' or 'stderr', chunk). + Must return the new callback_state for subsequent calls. + :param init: Initial callback state. + :param chunk_size: Maximum chunk size in bytes for each read. Defaults to a large integer. + :returns: A tuple of (final_callback_state, exit_code). + :raises Exception: If the callback itself raises an exception. + """ try: return _read_paramiko_streams_internal(stdout, stderr, select_timeout, callback, init, chunk_size) finally: @@ -118,11 +189,21 @@ def _read_paramiko_streams(stdout, stderr, select_timeout, callback, init, chunk stdout.channel.close() -def _read_paramiko_streams_internal(stdout, stderr, select_timeout, callback, init, chunk_size): +def _read_paramiko_streams_internal(stdout: 'ChannelFile', stderr: 'ChannelStderrFile', + select_timeout: Optional[float], + callback: Callable[[Optional[List[bytes]], str, bytes], List[bytes]], + init: Optional[List[bytes]], chunk_size: int) -> Tuple[Optional[List[bytes]], int]: + """ + Internal helper for :func:`_read_paramiko_streams`. 
+ """ channel = stdout.channel assert stdout.channel is stderr.channel - def read_channel(callback_state): + def read_channel(callback_state: Optional[List[bytes]]) -> Tuple[Optional[Exception], Optional[List[bytes]]]: + """ + read data from the channel, stdout or stderr + """ + read_list: List['Channel'] read_list, _, _ = select.select([channel], [], [], select_timeout) for desc in read_list: for ready, recv, name in ( @@ -130,7 +211,7 @@ def read_channel(callback_state): (desc.recv_stderr_ready(), desc.recv_stderr, 'stderr') ): if ready: - chunk = recv(chunk_size) + chunk: bytes = recv(chunk_size) if chunk: try: callback_state = callback(callback_state, name, chunk) @@ -139,7 +220,11 @@ def read_channel(callback_state): return (None, callback_state) - def read_all_channel(callback=None, callback_state=None): + def read_all_channel(callback: Optional[Callable[[Optional[List[bytes]], str, bytes], List[bytes]]] = None, + callback_state: Optional[List[bytes]] = None) -> Optional[List[bytes]]: + """ + read data from both stdout and stderr + """ for stream, name in ((stdout, 'stdout'), (stderr, 'stderr')): try: chunk = stream.read() @@ -151,7 +236,7 @@ def read_all_channel(callback=None, callback_state=None): return callback_state - callback_excep = None + callback_excep: Optional[Exception] = None try: callback_state = init while not channel.exit_status_ready(): @@ -176,7 +261,13 @@ def read_all_channel(callback=None, callback_state=None): return (callback_state, exit_code) -def _resolve_known_hosts(strict_host_check): +def _resolve_known_hosts(strict_host_check: Optional[Union[bool, str, os.PathLike]]) -> str: + """ + Compute a path to the known_hosts file based on ``strict_host_check``. + + :param strict_host_check: If True, uses ~/.ssh/known_hosts; if a path is given, uses that path; if False, returns /dev/null. + :returns: Absolute path to the known_hosts file (or '/dev/null'). 
+ """ if strict_host_check: if isinstance(strict_host_check, (str, os.PathLike)): path = Path(strict_host_check) @@ -188,13 +279,25 @@ def _resolve_known_hosts(strict_host_check): return str(path.resolve()) -def telnet_get_shell(host, - username, - password=None, - port=None, - timeout=10, - original_prompt=None): - start_time = time.time() +def telnet_get_shell(host: str, + username: str, + password: Optional[str] = None, + port: Optional[int] = None, + timeout: float = 10, + original_prompt: Optional[str] = None) -> 'TelnetPxssh': + """ + Obtain a Telnet shell by calling :class:`TelnetPxssh`. + + :param host: The host name or IP address for the Telnet connection. + :param username: The username for Telnet login. + :param password: Password for Telnet login, or None if no password is needed. + :param port: TCP port for Telnet. Defaults to 23 if unspecified. + :param timeout: Time in seconds to wait for the initial connection. + :param original_prompt: Regex for matching the shell prompt if it differs from default. + :returns: A TelnetPxssh object for interacting with the shell. + :raises TargetTransientError: If connection fails repeatedly within the timeout period. + """ + start_time: float = time.time() while True: conn = TelnetPxssh(original_prompt=original_prompt) @@ -217,25 +320,47 @@ def telnet_get_shell(host, class TelnetPxssh(pxssh.pxssh): # pylint: disable=arguments-differ + """ + A specialized Telnet-based shell session class, derived from :class:`pxssh.pxssh`. - def __init__(self, original_prompt): + :param original_prompt: A regex pattern for the shell's default prompt. 
+ """ + def __init__(self, original_prompt: Optional[str]): super(TelnetPxssh, self).__init__() self.original_prompt = original_prompt or r'[#$]' - def login(self, server, username, password='', login_timeout=10, - auto_prompt_reset=True, sync_multiplier=1, port=23): - args = ['telnet'] + def login(self, server: str, username: str, password: Optional[str] = '', login_timeout: float = 10, + auto_prompt_reset: bool = True, sync_multiplier: int = 1, port: Optional[int] = 23) -> bool: + """ + Attempt Telnet login, specifying a host, username, and optional password. + + :param server: Host name or IP address. + :param username: Username to log in with. + :param password: Password, if any, or empty string. + :param login_timeout: Time in seconds to wait for login prompts before failing. + :param auto_prompt_reset: If True, attempt to detect and set a unique prompt. + :param sync_multiplier: Adjust how aggressively pxssh synchronizes prompt detection. + :param port: Telnet port, default 23. + :returns: True if login was successful. + :raises pxssh.ExceptionPxssh: If login fails or the password is incorrect. + :raises TIMEOUT: If no password prompt is shown within the timeout. + """ + args: List[str] = ['telnet'] if username is not None: args += ['-l', username] args += [server, str(port)] - cmd = ' '.join(args) - - spawn._spawn(self, cmd) # pylint: disable=protected-access + cmd: str = ' '.join(args) + # FIXME - Modified the access to _spawn protected method and instead use public method of pexpect. 
+ # need to see if there is any issue with the replacement + # Spawn the command + child = pexpect.spawn(cmd) + # Wait for the command to complete + child.expect(pexpect.EOF) try: - i = self.expect('(?i)(?:password)', timeout=login_timeout) + i: int = self.expect('(?i)(?:password)', timeout=login_timeout) if i == 0: - self.sendline(password) + self.sendline(password or '') i = self.expect([self.original_prompt, 'Login incorrect'], timeout=login_timeout) if i: raise pxssh.ExceptionPxssh('could not log in: password was incorrect') @@ -259,18 +384,22 @@ def login(self, server, username, password='', login_timeout=10, return True -def check_keyfile(keyfile): +def check_keyfile(keyfile: str) -> str: """ keyfile must have the right access premissions in order to be useable. If the specified file doesn't, create a temporary copy and set the right permissions for that. Returns either the ``keyfile`` (if the permissions on it are correct) or the path to a temporary copy with the right permissions. + + :param keyfile: The path to the SSH private key file. + :returns: Either the original ``keyfile`` (if it already has 0600 perms) + or a temporary copy path with corrected permissions. """ - desired_mask = stat.S_IWUSR | stat.S_IRUSR - actual_mask = os.stat(keyfile).st_mode & 0xFF + desired_mask: int = stat.S_IWUSR | stat.S_IRUSR + actual_mask: int = os.stat(keyfile).st_mode & 0xFF if actual_mask != desired_mask: - tmp_file = os.path.join(tempfile.gettempdir(), os.path.basename(keyfile)) + tmp_file: str = os.path.join(tempfile.gettempdir(), os.path.basename(keyfile)) shutil.copy(keyfile, tmp_file) os.chmod(tmp_file, desired_mask) return tmp_file @@ -280,17 +409,73 @@ def check_keyfile(keyfile): class SshConnectionBase(ConnectionBase): """ - Base class for SSH connections. + Base class for SSH-derived connections, providing shared functionality + like verifying keyfile permissions, tracking host info, and more. + + :param host: The SSH target hostname or IP address. 
+ :param username: Username to log in as. + :param password: Password for the SSH connection, or None if key-based auth is used. + :param keyfile: Path to an SSH private key if using key-based auth. + :param port: TCP port for the SSH server. Defaults to 22 if unspecified. + :param platform: A devlib.platform.Platform instance describing the device. + :param sudo_cmd: A template string for granting sudo privileges (e.g. "sudo -S sh -c {}"). + :param strict_host_check: If True, host key checking is enforced using a known_hosts file. + If a string/path is supplied, that path is used as known_hosts. If False, host keys are not checked. + :param poll_transfers: If True, uses :class:`TransferManager` to poll file transfers. + :param start_transfer_poll_delay: Delay in seconds before the first poll of a new file transfer. + :param total_transfer_timeout: If a file transfer exceeds this many seconds, it is canceled. + :param transfer_poll_period: Interval (seconds) between file transfer progress checks. 
""" + def __init__(self, + host: str, + username: str, + password: Optional[str] = None, + keyfile: Optional[str] = None, + port: Optional[int] = None, + platform: Optional['Platform'] = None, + sudo_cmd: str = DEFAULT_SSH_SUDO_COMMAND, + strict_host_check: Union[bool, str, os.PathLike] = True, + poll_transfers: bool = False, + start_transfer_poll_delay: int = 30, + total_transfer_timeout: int = 3600, + transfer_poll_period: int = 30, + ): + super().__init__( + poll_transfers=poll_transfers, + start_transfer_poll_delay=start_transfer_poll_delay, + total_transfer_timeout=total_transfer_timeout, + transfer_poll_period=transfer_poll_period, + ) + self._connected_as_root: Optional[bool] = None + self.host = host + self.username = username + self.password = password + self.keyfile = check_keyfile(keyfile) if keyfile else keyfile + self.port = port + self.sudo_cmd = sanitize_cmd_template(sudo_cmd) + self.platform = platform + self.strict_host_check = strict_host_check + logger.debug('Logging in {}@{}'.format(username, host)) - default_timeout = 10 + default_timeout: int = 10 + """ + Default timeout in seconds for SSH operations if not otherwise specified. + """ @property - def name(self): + def name(self) -> str: + """ + :returns: A string identifying the host (e.g. the IP or hostname). + """ return self.host @property - def connected_as_root(self): + def connected_as_root(self) -> bool: + """ + Indicates if the current user on the remote SSH session is root (uid=0). + + :returns: True if root, else False. + """ if self._connected_as_root is None: try: result = self.execute('id', as_root=False) @@ -303,58 +488,84 @@ def connected_as_root(self): @connected_as_root.setter def connected_as_root(self, state): - self._connected_as_root = state + """ + Explicitly set the known state of root usage on this connection. 
- def __init__(self, - host, - username, - password=None, - keyfile=None, - port=None, - platform=None, - sudo_cmd=DEFAULT_SSH_SUDO_COMMAND, - strict_host_check=True, - - poll_transfers=False, - start_transfer_poll_delay=30, - total_transfer_timeout=3600, - transfer_poll_period=30, - ): - super().__init__( - poll_transfers=poll_transfers, - start_transfer_poll_delay=start_transfer_poll_delay, - total_transfer_timeout=total_transfer_timeout, - transfer_poll_period=transfer_poll_period, - ) - self._connected_as_root = None - self.host = host - self.username = username - self.password = password - self.keyfile = check_keyfile(keyfile) if keyfile else keyfile - self.port = port - self.sudo_cmd = sanitize_cmd_template(sudo_cmd) - self.platform = platform - self.strict_host_check = strict_host_check - logger.debug('Logging in {}@{}'.format(username, host)) + :param state: True if effectively root, False otherwise. + """ + self._connected_as_root = state class SshConnection(SshConnectionBase): + """ + A connection to a device on the network over SSH. + + :param host: SSH host to which to connect + :param username: username for SSH login + :param password: password for the SSH connection + + .. note:: To connect to a system without a password this + parameter should be set to an empty string otherwise + ssh key authentication will be attempted. + .. note:: In order to use password-based authentication, + ``sshpass`` utility must be installed on the + system. + + :param keyfile: Path to the SSH private key to be used for the connection. + + .. note:: ``keyfile`` and ``password`` can't be specified + at the same time. + + :param port: TCP port on which SSH server is listening on the remote device. + Omit to use the default port. + :param timeout: Timeout for the connection in seconds. If a connection + cannot be established within this time, an error will be + raised. + :param platform: Specify the platform to be used.
The generic :class:`~devlib.platform.Platform` + class is used by default. + :param sudo_cmd: Specify the format of the command used to grant sudo access. + :param strict_host_check: Specify the ssh connection parameter + ``StrictHostKeyChecking``. If a path is passed + rather than a boolean, it will be taken for a + ``known_hosts`` file. Otherwise, the default + ``$HOME/.ssh/known_hosts`` will be used. + :param use_scp: If True, prefer using the scp binary for file transfers instead of SFTP. + :param poll_transfers: Specify whether file transfers should be polled. Polling + monitors the progress of file transfers and periodically + checks whether they have stalled, attempting to cancel + the transfers prematurely if so. + :param start_transfer_poll_delay: If transfers are polled, specify the length of + time after a transfer has started before polling + should start. + :param total_transfer_timeout: If transfers are polled, specify the total amount of time + to elapse before the transfer is cancelled, regardless + of its activity. + :param transfer_poll_period: If transfers are polled, specify the period at which + the transfers are sampled for activity. Too small values + may cause the destination size to appear the same over + one or more sample periods, causing improper transfer + cancellation. + + :raises TargetNotRespondingError: If the SSH server cannot be reached. + :raises HostError: If the password or keyfile are invalid, or scp/sftp cannot be opened. + :raises TargetStableError: If authentication fails or paramiko encounters an unrecoverable error. 
+ """ # pylint: disable=unused-argument,super-init-not-called def __init__(self, - host, - username, - password=None, - keyfile=None, - port=22, - timeout=None, - platform=None, - sudo_cmd=DEFAULT_SSH_SUDO_COMMAND, - strict_host_check=True, - use_scp=False, - poll_transfers=False, - start_transfer_poll_delay=30, - total_transfer_timeout=3600, - transfer_poll_period=30, + host: str, + username: str, + password: Optional[str] = None, + keyfile: Optional[str] = None, + port: Optional[int] = 22, + timeout: Optional[int] = None, + platform: Optional['Platform'] = None, + sudo_cmd: str = DEFAULT_SSH_SUDO_COMMAND, + strict_host_check: Union[bool, str, os.PathLike] = True, + use_scp: bool = False, + poll_transfers: bool = False, + start_transfer_poll_delay: int = 30, + total_transfer_timeout: int = 3600, + transfer_poll_period: int = 30, ): super().__init__( @@ -381,7 +592,7 @@ def __init__(self, else: logger.debug('Using SFTP for file transfer') - self.client = None + self.client: Optional[SSHClient] = None try: self.client = self._make_client() @@ -391,7 +602,7 @@ def __init__(self, # everything will work as long as we login as root). If sudo is still # needed, it will explode when someone tries to use it. After all, the # user might not be interested in being root at all. 
- self._sudo_needs_password = ( + self._sudo_needs_password: bool = ( 'NEED_PASSWORD' in self.execute( # sudo -n is broken on some versions on MacOSX, revisit that if @@ -410,13 +621,16 @@ def __init__(self, finally: raise e - def _make_client(self): + def _make_client(self) -> SSHClient: + """ + Create, connect and return a class:SSHClient object + """ if self.strict_host_check: - policy = RejectPolicy + policy: Type[MissingHostKeyPolicy] = RejectPolicy else: policy = AutoAddPolicy # Only try using SSH keys if we're not using a password - check_ssh_keys = self.password is None + check_ssh_keys: bool = self.password is None with _handle_paramiko_exceptions(): client = SSHClient() @@ -427,7 +641,7 @@ def _make_client(self): client.set_missing_host_key_policy(policy) client.connect( hostname=self.host, - port=self.port, + port=self.port or 0, username=self.username, password=self.password, key_filename=self.keyfile, @@ -438,19 +652,28 @@ def _make_client(self): return client - def _make_channel(self): + def _make_channel(self) -> Optional['Channel']: + """ + The Transport class in the Paramiko library is a core component for handling SSH connections. + It attaches to a stream (usually a socket), negotiates an encrypted session, authenticates, + and then creates stream tunnels, called channels, across the session. + Multiple channels can be multiplexed across a single session + """ with _handle_paramiko_exceptions(): - transport = self.client.get_transport() - channel = transport.open_session() + transport: Optional['Transport'] = self.client.get_transport() if self.client else None + channel = transport.open_session() if transport else None return channel # Limit the number of opened channels to a low number, since some servers # will reject more connections request. For OpenSSH, this is controlled by # the MaxSessions config. 
@functools.lru_cache(maxsize=1) - def _cached_get_sftp(self): + def _cached_get_sftp(self) -> Optional['SFTPClient']: + """ + get the cached sftp channel to avoid opening too many channels to server + """ try: - sftp = self.client.open_sftp() + sftp: Optional['SFTPClient'] = self.client.open_sftp() if self.client else None except paramiko.ssh_exception.SSHException as e: if 'EOF during negotiation' in str(e): raise TargetStableError('The SSH server does not support SFTP. Please install and enable appropriate module.') from e @@ -458,21 +681,41 @@ def _cached_get_sftp(self): raise return sftp - def _get_sftp(self, timeout): - sftp = self._cached_get_sftp() - sftp.get_channel().settimeout(timeout) + def _get_sftp(self, timeout: Optional[float]) -> Optional['SFTPClient']: + """ + get the cached sftp channel and set a channel timeout for read write operations. + returns the channel with the timeout set + """ + sftp: Optional['SFTPClient'] = self._cached_get_sftp() + if sftp: + channel = sftp.get_channel() + if channel: + channel.settimeout(timeout) return sftp @functools.lru_cache() - def _get_scp(self, timeout, callback=lambda *_: None): - cb = lambda _, to_transfer, transferred: callback(to_transfer, transferred) - return SCPClient(self.client.get_transport(), socket_timeout=timeout, progress=cb) - - def _push_file(self, sftp, src, dst, callback): + def _get_scp(self, timeout: float, callback: Callable[..., None] = lambda *_: None) -> Optional[SCPClient]: + """ + get scp client as a class:SCPClient object + """ + cb: Callable[[bytes, int, int], None] = lambda _, to_transfer, transferred: callback(to_transfer, transferred) + if self.client: + transport: Optional['Transport'] = self.client.get_transport() + if transport: + return SCPClient(transport, socket_timeout=timeout, progress=cb) + return None + + def _push_file(self, sftp: 'SFTPClient', src: str, dst: str, callback: Optional[Callable]) -> None: + """ + push file to device via SFTP client + """ sftp.put(src, 
dst, callback=callback) @classmethod - def _path_exists(cls, sftp, path): + def _path_exists(cls, sftp: 'SFTPClient', path: str) -> bool: + """ + check whether the path exists on the device + """ try: sftp.lstat(path) except FileNotFoundError: @@ -480,12 +723,16 @@ def _path_exists(cls, sftp, path): else: return True - def _push_folder(self, sftp, src, dst, callback): + def _push_folder(self, sftp: 'SFTPClient', src: str, dst: str, + callback: Optional[Callable]) -> None: + """ + push a folder into device via SFTP client + """ sftp.mkdir(dst) for entry in os.scandir(src): - name = entry.name - src_path = os.path.join(src, name) - dst_path = os.path.join(dst, name) + name: str = entry.name + src_path: str = os.path.join(src, name) + dst_path: str = os.path.join(dst, name) if entry.is_dir(): push = self._push_folder else: @@ -493,12 +740,20 @@ def _push_folder(self, sftp, src, dst, callback): push(sftp, src_path, dst_path, callback) - def _push_path(self, sftp, src, dst, callback=None): + def _push_path(self, sftp: 'SFTPClient', src: str, dst: str, + callback: Optional[Callable] = None) -> None: + """ + push a path via sftp client + """ logger.debug('Pushing via sftp: {} -> {}'.format(src, dst)) push = self._push_folder if os.path.isdir(src) else self._push_file push(sftp, src, dst, callback) - def _pull_file(self, sftp, src, dst, callback): + def _pull_file(self, sftp: 'SFTPClient', src: str, dst: str, + callback: Optional[Callable]) -> None: + """ + pull a file via sftp client + """ try: sftp.get(src, dst, callback=callback) except Exception as e: @@ -512,20 +767,28 @@ def _pull_file(self, sftp, src, dst, callback): pass raise e - def _pull_folder(self, sftp, src, dst, callback): + def _pull_folder(self, sftp: 'SFTPClient', src: str, dst: str, + callback: Optional[Callable]) -> None: + """ + pull a folder via sftp client + """ os.makedirs(dst) for fileattr in sftp.listdir_attr(src): filename = fileattr.filename src_path = os.path.join(src, filename) dst_path = 
os.path.join(dst, filename) - if stat.S_ISDIR(fileattr.st_mode): + if stat.S_ISDIR(fileattr.st_mode or 0): pull = self._pull_folder else: pull = self._pull_file pull(sftp, src_path, dst_path, callback) - def _pull_path(self, sftp, src, dst, callback=None): + def _pull_path(self, sftp: 'SFTPClient', src: str, dst: str, + callback: Optional[Callable] = None) -> None: + """ + pull a path from the device via sftp client + """ logger.debug('Pulling via sftp: {} -> {}'.format(src, dst)) try: self._pull_file(sftp, src, dst, callback) @@ -533,58 +796,105 @@ def _pull_path(self, sftp, src, dst, callback=None): # Maybe that was a directory, so retry as such self._pull_folder(sftp, src, dst, callback) - def push(self, sources, dest, timeout=None): + def push(self, sources: List[str], dest: str, timeout: Optional[int] = None) -> None: + """ + Transfer (push) one or more files from the host to the remote target. + + :param sources: A List of paths on the host system to be pushed. + :param dest: Destination path on the remote device. If multiple sources, it should be a directory. + :param timeout: Optional time limit in seconds for each file transfer. If exceeded, raises an error. + :raises TargetStableError: If uploading fails or the remote host is not ready. + :raises HostError: If local scp or sftp usage fails. + """ self._push_pull('push', sources, dest, timeout) - def pull(self, sources, dest, timeout=None): + def pull(self, sources: List[str], dest: str, timeout: Optional[int] = None) -> None: + """ + Transfer (pull) one or more files from the remote target to the host. + + :param sources: A tuple of paths on the remote device to be pulled. + :param dest: Destination path on the host. If multiple sources, it should be a directory. + :param timeout: Optional time limit in seconds for each file transfer. + :raises TargetStableError: If downloading fails on the remote side. + :raises HostError: If local scp or sftp usage fails. 
+ """ self._push_pull('pull', sources, dest, timeout) - def _push_pull(self, action, sources, dest, timeout): + def _push_pull(self, action: Union[Literal['push'], Literal['pull']], + sources: List[str], dest: str, timeout: Optional[int]) -> None: + """ + Internal helper to handle both push and pull operations, optionally + using SCP or SFTP, with optional timeouts or polling. + + :param action: Either 'push' or 'pull', indicating the transfer direction. + :param sources: Paths to upload/download. + :param dest: The destination path, on host (for pull) or remote (for push). + :param timeout: If set, a per-file time limit (seconds) for the operation. + :raises TargetStableError: If the remote side fails or scp/sftp commands fail. + :raises HostError: If local environment or tools are unavailable. + """ if action not in ['push', 'pull']: raise ValueError("Action must be either `push` or `pull`") - def make_handle(obj): + def make_handle(obj: Union[SCPClient, 'SFTPClient']): handle = SSHTransferHandle(obj, manager=self.transfer_manager) - cm = self.transfer_manager.manage(sources, dest, action, handle) + cm = cast(TransferManager, self.transfer_manager).manage(sources, dest, action, handle) return (handle, cm) # If timeout is set if timeout is not None: if self.use_scp: - scp = self._get_scp(timeout) - scp_cmd = getattr(scp, 'put' if action == 'push' else 'get') - scp_msg = '{}ing via scp: {} -> {}'.format(action, sources, dest) + scp: Optional[SCPClient] = self._get_scp(timeout) + scp_cmd: Callable = getattr(scp, 'put' if action == 'push' else 'get') + scp_msg: str = '{}ing via scp: {} -> {}'.format(action, sources, dest) logger.debug(scp_msg.capitalize()) scp_cmd(sources, dest, recursive=True) else: - sftp = self._get_sftp(timeout) - sftp_cmd = getattr(self, '_' + action + '_path') + sftp: Optional['SFTPClient'] = self._get_sftp(timeout) + sftp_cmd: Callable = getattr(self, '_' + action + '_path') with _handle_paramiko_exceptions(): for source in sources: 
sftp_cmd(sftp, source, dest) # No timeout elif self.use_scp: - def progress_cb(*args, **kwargs): + def progress_cb(*args, **kwargs) -> None: return handle.progress_cb(*args, **kwargs) scp = self._get_scp(timeout, callback=progress_cb) - handle, cm = make_handle(scp) + if scp: + handle, cm = make_handle(scp) scp_cmd = getattr(scp, 'put' if action == 'push' else 'get') - with _handle_paramiko_exceptions(), cm: + with _handle_paramiko_exceptions(), cast(contextlib._GeneratorContextManager, cm): scp_msg = '{}ing via scp: {} -> {}'.format(action, sources, dest) logger.debug(scp_msg.capitalize()) scp_cmd(sources, dest, recursive=True) else: sftp = self._get_sftp(timeout) - handle, cm = make_handle(sftp) - sftp_cmd = getattr(self, '_' + action + '_path') - with _handle_paramiko_exceptions(), cm: - for source in sources: - sftp_cmd(sftp, source, dest, callback=handle.progress_cb) - - def execute(self, command, timeout=None, check_exit_code=True, - as_root=False, strip_colors=True, will_succeed=False): #pylint: disable=unused-argument + if sftp: + handle, cm = make_handle(sftp) + sftp_cmd = getattr(self, '_' + action + '_path') + with _handle_paramiko_exceptions(), cm: + for source in sources: + sftp_cmd(sftp, source, dest, callback=handle.progress_cb) + + def execute(self, command: 'SubprocessCommand', timeout: Optional[int] = None, check_exit_code: bool = True, + as_root: Optional[bool] = False, strip_colors: bool = True, will_succeed: bool = False) -> str: # pylint: disable=unused-argument + """ + Run a command synchronously on the remote machine, capturing its output. + By default, raises an exception if the command returns a non-zero exit code. + + :param command: The shell command to run, as a string or SubprocessCommand object. + :param timeout: Maximum time in seconds to wait for completion. If None, uses a default or indefinite wait. + :param check_exit_code: If True, raise an error if the command's exit code is non-zero. 
+ :param as_root: If True, attempt to run the command via sudo unless already connected as root. + :param strip_colors: If True, remove ANSI color codes from the captured output. + :param will_succeed: If True, treat a non-zero exit code as transient instead of stable. + :returns: The combined stdout/stderr of the command. + :raises TargetTransientCalledProcessError: If `check_exit_code=True` and the command fails while `will_succeed=True`. + :raises TargetStableCalledProcessError: If `check_exit_code=True` and the command fails while `will_succeed=False`. + :raises TargetStableError: If a stable SSH or environment error occurs. + """ if command == '': return '' try: @@ -608,26 +918,57 @@ def execute(self, command, timeout=None, check_exit_code=True, ) return output - def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False): + def background(self, command: 'SubprocessCommand', stdout: int = subprocess.PIPE, + stderr: int = subprocess.PIPE, as_root: Optional[bool] = False) -> ParamikoBackgroundCommand: + """ + Execute a command in the background on the remote host, returning a handle + to manage it. The command runs until completion or cancellation. + + :param command: The command to run. + :param stdout: Where to direct the command's stdout (default: subprocess.PIPE). + :param stderr: Where to direct the command's stderr (default: subprocess.PIPE). + :param as_root: If True, attempt to run under sudo unless already root. + :returns: A :class:`ParamikoBackgroundCommand` instance to manage or query the process. + :raises TargetStableError: If channel creation fails or paramiko indicates a stable error. + :raises TargetNotRespondingError: If the SSH session is lost unexpectedly. + + .. note:: This **will block the connection** until the command completes. 
+ """ with _handle_paramiko_exceptions(command): return self._background(command, stdout, stderr, as_root) - def _background(self, command, stdout, stderr, as_root): + def _background(self, command: 'SubprocessCommand', stdout: int, + stderr: int, as_root: Optional[bool]) -> ParamikoBackgroundCommand: + """ + Internal helper for :meth:`background` that sets up the paramiko channel, + spawns the command, and wires up redirection threads. + + :param command: The shell command to execute in the background. + :param stdout: Destination for stdout (int file descriptor or special constant). + :param stderr: Destination for stderr (int file descriptor or special constant). + :param as_root: If True, run under sudo (if not already root). + :returns: The background command object. + :raises subprocess.CalledProcessError: If we cannot detect a valid PID or if the remote fails immediately. + :raises TargetStableError: If paramiko cannot open a session or other stable error occurs. + """ def make_init_kwargs(command): _stdout, _stderr, _command = redirect_streams(stdout, stderr, command) - _command = "printf '%s\n' $$; exec sh -c {}".format(quote(_command)) + _command = "printf '%s\n' $$; exec sh -c {}".format(quote(cast(str, _command))) channel = self._make_channel() - def executor(cmd, timeout): - channel.exec_command(cmd) - # Read are not buffered so we will always get the data as soon as - # they arrive - return ( - channel.makefile_stdin('w', 0), - channel.makefile(), - channel.makefile_stderr(), - ) + def executor(cmd, timeout) -> ChannelFiles: + if channel: + channel.exec_command(cmd) + # Read are not buffered so we will always get the data as soon as + # they arrive + return ( + channel.makefile_stdin('w', 0), + channel.makefile(), + channel.makefile_stderr(), + ) + else: + return cast(ChannelFiles, (None, None, None)) stdin, stdout_in, stderr_in = self._execute_command( _command, @@ -639,7 +980,7 @@ def executor(cmd, timeout): pid = stdout_in.readline() if not pid: 
_stderr = stderr_in.read() - if channel.exit_status_ready(): + if (channel is not None) and (channel.exit_status_ready()): ret = channel.recv_exit_status() else: ret = 126 @@ -751,15 +1092,40 @@ def callback(out_streams, name, chunk): make_init_kwargs=make_init_kwargs, ) - def _close(self): - logger.debug('Logging out {}@{}'.format(self.username, self.host)) - with _handle_paramiko_exceptions(): - self.client.close() + def _close(self) -> None: + """ + Close the SSH connection, releasing any underlying resources such as paramiko + sessions or sockets. After this call, the SshConnection is no longer usable. - def _execute_command(self, command, as_root, log, timeout, executor): + :raises TargetStableError: If a stable error occurs during disconnection. + """ + if logger: + logger.debug('Logging out {}@{}'.format(self.username, self.host)) + if '_handle_paramiko_exceptions' in globals() and (_handle_paramiko_exceptions is not None): + with _handle_paramiko_exceptions(): + if self.client: + self.client.close() + else: + if self.client: + self.client.close() # Fallback if _handle_paramiko_exceptions is missing + + def _execute_command(self, command: str, as_root: Optional[bool], + log: bool, timeout: Optional[int], + executor: Callable[..., ChannelFiles]) -> ChannelFiles: + """ + execute the command over the channel using the executor and return the channel in, out and err files + """ + def get_logger(log: bool) -> Callable[..., None]: + """ + get the logger + """ + if log: + return logger.debug + else: + return lambda msg: None # As we're already root, there is no need to use sudo. 
- log_debug = logger.debug if log else lambda msg: None - use_sudo = as_root and not self.connected_as_root + log_debug = get_logger(log) + use_sudo: Optional[bool] = as_root and not self.connected_as_root if use_sudo: if self._sudo_needs_password and not self.password: @@ -768,8 +1134,8 @@ def _execute_command(self, command, as_root, log, timeout, executor): command = self.sudo_cmd.format(quote(command)) log_debug(command) - streams = executor(command, timeout=timeout) - if self._sudo_needs_password: + streams: ChannelFiles = executor(command, timeout=timeout) + if self._sudo_needs_password and streams and self.password: stdin = streams[0] stdin.write(self.password + '\n') stdin.flush() @@ -779,10 +1145,16 @@ def _execute_command(self, command, as_root, log, timeout, executor): return streams - def _execute(self, command, timeout=None, as_root=False, strip_colors=True, log=True): + def _execute(self, command: 'SubprocessCommand', timeout: Optional[int] = None, + as_root: Optional[bool] = False, strip_colors: bool = True, + log: bool = True) -> Tuple[int, str]: + """ + execute the command and return the exit code and output + """ # Merge stderr into stdout since we are going without a TTY - command = '({}) 2>&1'.format(command) - + command = '({}) 2>&1'.format(cast(str, command)) + if self.client is None: + raise TargetStableError("client is None") stdin, stdout, stderr = self._execute_command( command, as_root=as_root, @@ -794,36 +1166,77 @@ def _execute(self, command, timeout=None, as_root=False, strip_colors=True, log= # Empty the stdout buffer of the command, allowing it to carry on to # completion - def callback(output_chunks, name, chunk): + def callback(output_chunks: List[bytes], name: str, chunk: bytes) -> List[bytes]: + """ + callback for _read_paramiko_streams + """ output_chunks.append(chunk) return output_chunks - select_timeout = 1 + select_timeout: float = 1 output_chunks, exit_code = _read_paramiko_streams(stdout, stderr, select_timeout, callback, 
[]) + if output_chunks is None: + raise TargetStableError("output_chunks is None") # Join in one go to avoid O(N^2) concatenation - output = b''.join(output_chunks) - output = output.decode(sys.stdout.encoding or 'utf-8', 'replace') + output_b = b''.join(output_chunks) + output = output_b.decode(sys.stdout.encoding or 'utf-8', 'replace') return (exit_code, output) class TelnetConnection(SshConnectionBase): + """ + A connection using the Telnet protocol. In practice, this implements minimal + features such as command execution, but leverages local scp if needed for file + transfers (since Telnet does not provide a built-in file transfer mechanism). + + .. note:: Since the Telnet protocol does not support file transfer, scp is + used for that purpose. + + :param host: SSH host to which to connect + :param username: username for SSH login + :param password: password for the SSH connection + + .. note:: In order to use password-based authentication, + ``sshpass`` utility must be installed on the system. + + :param port: TCP port on which SSH server is listening on the remote device. + Omit to use the default port. + :param timeout: Timeout for the connection in seconds. If a connection + cannot be established within this time, an error will be + raised. + :param password_prompt: A string with the password prompt used by + ``sshpass``. Set this if your version of ``sshpass`` + uses something other than ``"[sudo] password"``. + :param original_prompt: A regex for the shell prompt presented in the Telnet + connection (the prompt will be reset to a + randomly-generated pattern for the duration of the + connection to reduce the possibility of clashes). + This parameter is ignored for SSH connections. + :param sudo_cmd: Template string for running commands with sudo privileges. + :param strict_host_check: Ignored for Telnet connections, included for interface consistency. + :param platform: A devlib Platform describing hardware or OS features.
+ + :raises TargetNotRespondingError: If the Telnet server is not reachable. + :raises HostError: If local scp usage fails for file transfers. + :raises TargetStableError: If login fails or commands cannot be executed. + """ - default_password_prompt = '[sudo] password' - max_cancel_attempts = 5 + default_password_prompt: str = '[sudo] password' + max_cancel_attempts: int = 5 # pylint: disable=unused-argument,super-init-not-called def __init__(self, - host, - username, - password=None, - port=None, - timeout=None, - password_prompt=None, - original_prompt=None, - sudo_cmd="sudo -- sh -c {}", - strict_host_check=True, - platform=None): + host: str, + username: str, + password: Optional[str] = None, + port: Optional[int] = None, + timeout: Optional[int] = None, + password_prompt: Optional[str] = None, + original_prompt: Optional[str] = None, + sudo_cmd: str = "sudo -- sh -c {}", + strict_host_check: Union[bool, str, os.PathLike] = True, + platform: Optional['Platform'] = None): super().__init__( host=host, @@ -843,12 +1256,18 @@ def __init__(self, logger.debug('Logging in {}@{}'.format(username, host)) timeout = timeout if timeout is not None else self.default_timeout - self.conn = telnet_get_shell(host, username, password, port, timeout, original_prompt) + self.conn: Optional['TelnetPxssh'] = telnet_get_shell(host, username, password, port, timeout, original_prompt) - def fmt_remote_path(self, path): + def fmt_remote_path(self, path: str) -> str: + """ + format remote path + """ return '{}@{}:{}'.format(self.username, self.host, path) - def _get_default_options(self): + def _get_default_options(self) -> Dict[str, str]: + """ + get defaults for stricthostcheck and known hosts + """ check = self.strict_host_check known_hosts = _resolve_known_hosts(check) return { @@ -856,13 +1275,19 @@ def _get_default_options(self): 'UserKnownHostsFile': str(known_hosts), } - def push(self, sources, dest, timeout=30): + def push(self, sources: List[str], dest: str, timeout: int = 30) 
-> None: + """ + push files to device through the connection + """ # Quote the destination as SCP would apply globbing too dest = self.fmt_remote_path(quote(dest)) paths = list(sources) + [dest] return self._scp(paths, timeout) - def pull(self, sources, dest, timeout=30): + def pull(self, sources: str, dest: str, timeout=30): + """ + pull files from device + """ # First level of escaping for the remote shell sources = ' '.join(map(quote, sources)) # All the sources are merged into one scp parameter @@ -870,22 +1295,22 @@ def pull(self, sources, dest, timeout=30): paths = [sources, dest] self._scp(paths, timeout) - def _scp(self, paths, timeout=30): + def _scp(self, paths: List[str], timeout=30): # NOTE: the version of scp in Ubuntu 12.04 occasionally (and bizarrely) # fails to connect to a device if port is explicitly specified using -P # option, even if it is the default port, 22. To minimize this problem, # only specify -P for scp if the port is *not* the default. - port_string = '-P {}'.format(quote(str(self.port))) if (self.port and self.port != 22) else '' - keyfile_string = '-i {}'.format(quote(self.keyfile)) if self.keyfile else '' - options = " ".join(["-o {}={}".format(key, val) - for key, val in self.options.items()]) - paths = ' '.join(map(quote, paths)) - command = '{} {} -r {} {} {}'.format(_SSH_ENV.get_path('scp'), - options, - keyfile_string, - port_string, - paths) - command_redacted = command + port_string: str = '-P {}'.format(quote(str(self.port))) if (self.port and self.port != 22) else '' + keyfile_string: str = '-i {}'.format(quote(self.keyfile)) if self.keyfile else '' + options: str = " ".join(["-o {}={}".format(key, val) + for key, val in self.options.items()]) + paths_s: str = ' '.join(map(quote, paths)) + command: str = '{} {} -r {} {} {}'.format(_SSH_ENV.get_path('scp'), + options, + keyfile_string, + port_string, + paths_s) + command_redacted: str = command logger.debug(command) if self.password: command, command_redacted = 
_give_password(self.password, command) @@ -897,21 +1322,20 @@ def _scp(self, paths, timeout=30): except TimeoutError as e: raise TimeoutError(command_redacted, e.output) - - def execute(self, command, timeout=None, check_exit_code=True, - as_root=False, strip_colors=True, will_succeed=False): #pylint: disable=unused-argument + def execute(self, command: 'SubprocessCommand', timeout: Optional[int] = None, check_exit_code: bool = True, + as_root: Optional[bool] = False, strip_colors: bool = True, will_succeed: bool = False) -> str: # pylint: disable=unused-argument if command == '': # Empty command is valid but the __devlib_ec stuff below will # produce a syntax error with bash. Treat as a special case. return '' try: with self.lock: - _command = '({}); __devlib_ec=$?; echo; echo $__devlib_ec'.format(command) + _command = '({}); __devlib_ec=$?; echo; echo $__devlib_ec'.format(cast(str, command)) full_output = self._execute_and_wait_for_prompt(_command, timeout, as_root, strip_colors) split_output = full_output.rsplit('\r\n', 2) try: output, exit_code_text, _ = split_output - except ValueError as e: + except ValueError: raise TargetStableError( "cannot split reply (target misconfiguration?):\n'{}'".format(full_output)) if check_exit_code: @@ -919,8 +1343,8 @@ def execute(self, command, timeout=None, check_exit_code=True, exit_code = int(exit_code_text) except (ValueError, IndexError): raise ValueError( - 'Could not get exit code for "{}",\ngot: "{}"'\ - .format(command, exit_code_text)) + 'Could not get exit code for "{}",\ngot: "{}"' + .format(cast(str, command), exit_code_text)) if exit_code: cls = TargetTransientCalledProcessError if will_succeed else TargetStableCalledProcessError raise cls( @@ -940,42 +1364,56 @@ def execute(self, command, timeout=None, check_exit_code=True, else: raise - def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False): + def background(self, command: 'SubprocessCommand', stdout: int = subprocess.PIPE, 
+                   stderr: int = subprocess.PIPE, as_root: Optional[bool] = False) -> 'Popen':
         try:
             port_string = '-p {}'.format(self.port) if self.port else ''
             keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else ''
             if as_root and not self.connected_as_root:
-                command = self.sudo_cmd.format(command)
-            options = " ".join([ "-o {}={}".format(key,val)
-                            for key,val in self.options.items()])
-            command = '{} {} {} {} {}@{} {}'.format(_SSH_ENV.get_path('ssh'),
-                                                    options,
-                                                    keyfile_string,
-                                                    port_string,
-                                                    self.username,
-                                                    self.host,
-                                                    command)
+                command = self.sudo_cmd.format(command)
+            options = " ".join(["-o {}={}".format(key, val)
+                                for key, val in self.options.items()])
+            command = '{} {} {} {} {}@{} {}'.format(_SSH_ENV.get_path('ssh'),
+                                                    options,
+                                                    keyfile_string,
+                                                    port_string,
+                                                    self.username,
+                                                    self.host,
+                                                    command)
             logger.debug(command)
             if self.password:
-                command, _ = _give_password(self.password, command)
+                command, _ = _give_password(self.password, cast(str, command))
             return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True)
         except EOF:
             raise TargetNotRespondingError('Connection lost.')
 
-    def _close(self):
+    def _close(self) -> None:
+        """
+        Close the connection to the device. The :class:`Connection` object should not
+        be used after this method is called. There is no way to reopen a previously
+        closed connection, a new connection object should be created instead.
+        """
         logger.debug('Logging out {}@{}'.format(self.username, self.host))
         try:
-            self.conn.logout()
+            if self.conn:
+                self.conn.logout()
         except:
             logger.debug('Connection lost.')
-        self.conn.close(force=True)
+        if self.conn:
+            self.conn.close(force=True)
 
-    def cancel_running_command(self):
+    def cancel_running_command(self) -> bool:
+        """
+        Cancel a running command (previously started with :func:`background`) and free up the connection.
+        It is valid to call this if the command has already terminated (or if no
+        command was issued), in which case this is a no-op.
+ """ + # FIXME - other instances of cancel_running_command is just returning None. should this also be changed to do the same? # simulate impatiently hitting ^C until command prompt appears logger.debug('Sending ^C') for _ in range(self.max_cancel_attempts): self._sendline(chr(3)) - if self.conn.prompt(0.1): + if self.conn and self.conn.prompt(0.1): return True return False @@ -985,16 +1423,23 @@ def wait_for_device(self, timeout=30): def reboot_bootloader(self, timeout=30): raise NotImplementedError() - def _execute_and_wait_for_prompt(self, command, timeout=None, as_root=False, strip_colors=True, log=True): + def _execute_and_wait_for_prompt(self, command: 'SubprocessCommand', timeout: Optional[int] = None, + as_root: Optional[bool] = False, strip_colors: bool = True, + log: bool = True) -> str: + """ + execute command and wait for prompt + """ + if not self.conn: + raise TargetStableError("conn is None") self.conn.prompt(0.1) # clear an existing prompt if there is one. if as_root and self.connected_as_root: # As we're already root, there is no need to use sudo. 
as_root = False if as_root: - command = self.sudo_cmd.format(quote(command)) + command = self.sudo_cmd.format(quote(cast(str, command))) if log: logger.debug(command) - self._sendline(command) + self._sendline(cast(str, command)) if self.password: index = self.conn.expect_exact([self.password_prompt, TIMEOUT], timeout=0.5) if index == 0: @@ -1002,8 +1447,10 @@ def _execute_and_wait_for_prompt(self, command, timeout=None, as_root=False, str else: # not as_root if log: logger.debug(command) - self._sendline(command) + self._sendline(cast(str, command)) timed_out = self._wait_for_prompt(timeout) + if self.conn.before is None: + raise TargetStableError("conn.before is None") output = process_backspaces(self.conn.before.decode(sys.stdout.encoding or 'utf-8', 'replace')) if timed_out: @@ -1013,28 +1460,45 @@ def _execute_and_wait_for_prompt(self, command, timeout=None, as_root=False, str output = strip_bash_colors(output) return output - def _wait_for_prompt(self, timeout=None): - if timeout: - return not self.conn.prompt(timeout) - else: # cannot timeout; wait forever - while not self.conn.prompt(1): - pass + def _wait_for_prompt(self, timeout: Optional[int] = None) -> bool: + """ + wait for prompt + """ + if self.conn: + if timeout: + return not self.conn.prompt(timeout) + else: # cannot timeout; wait forever + while not self.conn.prompt(1): + pass + return False + else: return False - def _sendline(self, command): + def _sendline(self, command: str) -> None: + """ + send a line of string + """ # Workaround for https://github.com/pexpect/pexpect/issues/552 if len(command) == self._get_window_size()[1] - self._get_prompt_length(): command += ' ' - self.conn.sendline(command) + if self.conn: + self.conn.sendline(command) @memoized - def _get_prompt_length(self): + def _get_prompt_length(self) -> int: + """ + get the length of the prompt + """ + if not self.conn: + raise TargetStableError("conn is none") self.conn.sendline() self.conn.prompt() - return 
len(self.conn.after) + return len(cast(Sized, self.conn.after)) @memoized - def _get_window_size(self): + def _get_window_size(self) -> Tuple[int, int]: + if not self.conn: + raise TargetStableError("conn is none") return self.conn.getwinsize() @@ -1056,9 +1520,9 @@ def __init__(self, host_system = socket.gethostname() if host_system != host: raise TargetStableError("Gem5Connection can only connect to gem5 " - "simulations on your current host {}, which " - "differs from the one given {}!" - .format(host_system, host)) + "simulations on your current host {}, which " + "differs from the one given {}!" + .format(host_system, host)) if username is not None and username != 'root': raise ValueError('User should be root in gem5!') if password is not None and password != '': @@ -1075,7 +1539,7 @@ def __init__(self, if timeout is not None: if timeout > self.default_timeout: logger.info('Overwriting the default timeout of gem5 ({})' - ' to {}'.format(self.default_timeout, timeout)) + ' to {}'.format(self.default_timeout, timeout)) self.default_timeout = timeout else: logger.info('Ignoring the given timeout --> gem5 needs longer timeouts') @@ -1092,7 +1556,7 @@ def __init__(self, # Lock file to prevent multiple connections to same gem5 simulation # (gem5 does not allow this) self.lock_directory = '/tmp/' - self.lock_file_name = None # Will be set once connected to gem5 + self.lock_file_name = None # Will be set once connected to gem5 # These parameters will be set by either the method to connect to the # gem5 platform or directly to the gem5 simulation @@ -1149,7 +1613,7 @@ def push(self, sources, dest, timeout=None): self._gem5_shell("ls -al {}".format(quote(self.gem5_input_dir))) logger.debug("Push complete.") - def pull(self, sources, dest, timeout=0): #pylint: disable=unused-argument + def pull(self, sources, dest, timeout=0): # pylint: disable=unused-argument """ Pull a file from the gem5 device using m5 writefile @@ -1175,30 +1639,32 @@ def pull(self, sources, dest, 
timeout=0): #pylint: disable=unused-argument # error if the file was not where we expected it to be. if os.path.isabs(source): if os.path.dirname(source) != self.execute('pwd', - check_exit_code=False): + check_exit_code=False): self._gem5_shell("cat {} > {}".format(quote(filename), - quote(dest_file))) + quote(dest_file))) self._gem5_shell("sync") self._gem5_shell("ls -la {}".format(dest_file)) logger.debug('Finished the copy in the simulator') self._gem5_util("writefile {}".format(dest_file)) if 'cpu' not in filename: - while not os.path.exists(os.path.join(self.gem5_out_dir, - dest_file)): - time.sleep(1) + if self.gem5_out_dir: + while not os.path.exists(os.path.join(self.gem5_out_dir, + dest_file)): + time.sleep(1) # Perform the local move if os.path.exists(os.path.join(dest, dest_file)): logger.warning( - 'Destination file {} already exists!'\ - .format(dest_file)) + 'Destination file {} already exists!' + .format(dest_file)) else: - shutil.move(os.path.join(self.gem5_out_dir, dest_file), dest) + if self.gem5_out_dir: + shutil.move(os.path.join(self.gem5_out_dir, dest_file), dest) logger.debug("Pull complete.") def execute(self, command, timeout=1000, check_exit_code=True, - as_root=False, strip_colors=True, will_succeed=False): + as_root: Optional[bool] = False, strip_colors=True, will_succeed=False): """ Execute a command on the gem5 platform """ @@ -1228,7 +1694,7 @@ def background(self, command, stdout=subprocess.PIPE, self._check_ready() # Create the logfile for stderr/stdout redirection - command_name = command.split(' ')[0].split('/')[-1] + command_name = cast(str, command).split(' ')[0].split('/')[-1] redirection_file = 'BACKGROUND_{}.log'.format(command_name) trial = 0 while os.path.isfile(redirection_file): @@ -1257,27 +1723,31 @@ def _close(self): # the end of a simulation! 
self._unmount_virtio() self._gem5_util("exit") - self.gem5simulation.wait() + if self.gem5simulation: + self.gem5simulation.wait() except EOF: pass gem5_logger.info("Removing the temporary directory") try: - shutil.rmtree(self.gem5_interact_dir) + if self.gem5_interact_dir: + shutil.rmtree(self.gem5_interact_dir) except OSError: gem5_logger.warning("Failed to remove the temporary directory!") # Delete the lock file - os.remove(self.lock_file_name) + if self.lock_file_name: + os.remove(self.lock_file_name) def wait_for_device(self, timeout=30): """ Wait for Gem5 to be ready for interation with a timeout. """ - for _ in attempts(timeout): + # FIXME - attempts function not defined. not sure if it is a library function or this is the right intention + for _ in attempts(timeout): # type:ignore if self.ready: return time.sleep(1) - raise TimeoutError('Gem5 is not ready for interaction') + raise TimeoutError('Gem5 is not ready for interaction', '') def reboot_bootloader(self, timeout=30): raise NotImplementedError() @@ -1308,7 +1778,7 @@ def _gem5_EOF_handler(self, gem5_simulation, gem5_out_dir, err): # This function connects to the gem5 simulation # pylint: disable=too-many-statements def connect_gem5(self, port, gem5_simulation, gem5_interact_dir, - gem5_out_dir): + gem5_out_dir): """ Connect to the telnet port of the gem5 simulation. @@ -1329,8 +1799,8 @@ def connect_gem5(self, port, gem5_simulation, gem5_interact_dir, if os.path.isfile(lock_file_name): # There is already a connection to this gem5 simulation raise TargetStableError('There is already a connection to the gem5 ' - 'simulation using port {} on {}!' - .format(port, host)) + 'simulation using port {} on {}!' + .format(port, host)) # Connect to the gem5 telnet port. Use a short timeout here. 
attempts = 0 @@ -1353,7 +1823,7 @@ def connect_gem5(self, port, gem5_simulation, gem5_interact_dir, # Create the lock file self.lock_file_name = lock_file_name - open(self.lock_file_name, 'w').close() # Similar to touch + open(self.lock_file_name, 'w').close() # Similar to touch gem5_logger.info("Created lock file {} to prevent reconnecting to " "same simulation".format(self.lock_file_name)) @@ -1409,6 +1879,8 @@ def _login_to_device(self): def _find_prompt(self): prompt = r'\[PEXPECT\][\\\$\#]+ ' synced = False + if self.conn is None: + raise TargetStableError("Conn is None") while not synced: self.conn.send('\n') i = self.conn.expect([prompt, self.conn.UNIQUE_PROMPT, r'[\$\#] '], timeout=self.default_timeout) @@ -1418,10 +1890,11 @@ def _find_prompt(self): prompt = self.conn.UNIQUE_PROMPT synced = True else: - prompt = re.sub(r'\$', r'\\\$', self.conn.before.strip() + self.conn.after.strip()) - prompt = re.sub(r'\#', r'\\\#', prompt) - prompt = re.sub(r'\[', r'\[', prompt) - prompt = re.sub(r'\]', r'\]', prompt) + if self.conn.before and self.conn.after: + prompt = re.sub(r'\$', r'\\\$', self.conn.before.strip() + cast(bytes, self.conn.after).strip()) + prompt = re.sub(r'\#', r'\\\#', prompt) + prompt = re.sub(r'\[', r'\[', prompt) + prompt = re.sub(r'\]', r'\]', prompt) self.conn.PROMPT = prompt @@ -1434,10 +1907,11 @@ def _sync_gem5_shell(self): both of these. 
""" gem5_logger.debug("Sending Sync") - self.conn.send("echo \\*\\*sync\\*\\*\n") - self.conn.expect(r"\*\*sync\*\*", timeout=self.default_timeout) - self.conn.expect([self.conn.UNIQUE_PROMPT, self.conn.PROMPT], timeout=self.default_timeout) - self.conn.expect([self.conn.UNIQUE_PROMPT, self.conn.PROMPT], timeout=self.default_timeout) + if self.conn: + self.conn.send("echo \\*\\*sync\\*\\*\n") + self.conn.expect(r"\*\*sync\*\*", timeout=self.default_timeout) + self.conn.expect([self.conn.UNIQUE_PROMPT, self.conn.PROMPT], timeout=self.default_timeout) + self.conn.expect([self.conn.UNIQUE_PROMPT, self.conn.PROMPT], timeout=self.default_timeout) def _gem5_util(self, command): """ Execute a gem5 utility command using the m5 binary on the device """ @@ -1445,7 +1919,8 @@ def _gem5_util(self, command): raise TargetStableError('Path to m5 binary on simulated system is not set!') self._gem5_shell('{} {}'.format(self.m5_path, command)) - def _gem5_shell(self, command, as_root=False, timeout=None, check_exit_code=True, sync=True, will_succeed=False): # pylint: disable=R0912 + def _gem5_shell(self, command, as_root: Optional[bool] = False, + timeout=None, check_exit_code=True, sync=True, will_succeed=False): # pylint: disable=R0912 """ Execute a command in the gem5 shell @@ -1465,7 +1940,8 @@ def _gem5_shell(self, command, as_root=False, timeout=None, check_exit_code=True if as_root: command = 'echo {} | su'.format(quote(command)) - + if self.conn is None: + raise TargetStableError("Conn is None") # Send the actual command self.conn.send("{}\n".format(command)) @@ -1475,7 +1951,7 @@ def _gem5_shell(self, command, as_root=False, timeout=None, check_exit_code=True command_index = -1 while command_index == -1: if self.conn.prompt(): - output = re.sub(r' \r([^\n])', r'\1', self.conn.before) + output = re.sub(r' \r([^\n])', r'\1', self.conn.before or '') output = re.sub(r'[\b]', r'', output) # Deal with line wrapping output = re.sub(r'[\r].+?<', r'', output) @@ -1486,7 +1962,7 @@ 
def _gem5_shell(self, command, as_root=False, timeout=None, check_exit_code=True # warn, and return the whole output. if command_index == -1: gem5_logger.warning("gem5_shell: Unable to match command in " - "command output. Expect parsing errors!") + "command output. Expect parsing errors!") command_index = 0 output = output[command_index + len(command):].strip() @@ -1506,15 +1982,15 @@ def _gem5_shell(self, command, as_root=False, timeout=None, check_exit_code=True if check_exit_code: exit_code_text = self._gem5_shell('echo $?', as_root=as_root, - timeout=timeout, check_exit_code=False, - sync=False) + timeout=timeout, check_exit_code=False, + sync=False) try: exit_code = int(exit_code_text.split()[0]) except (ValueError, IndexError): raise ValueError('Could not get exit code for "{}",\ngot: "{}"'.format(command, exit_code_text)) else: if exit_code: - cls = TragetTransientCalledProcessError if will_succeed else TargetStableCalledProcessError + cls = TargetTransientCalledProcessError if will_succeed else TargetStableCalledProcessError raise cls( exit_code, command, @@ -1594,20 +2070,21 @@ def _login_to_device(self): gem5_logger.info("Trying to log in to gem5 device") login_prompt = ['login:', 'AEL login:', 'username:', 'aarch64-gem5 login:'] login_password_prompt = ['password:'] + if self.conn is None: + raise TargetStableError("Conn is None") # Wait for the login prompt prompt = login_prompt + [self.conn.UNIQUE_PROMPT] - i = self.conn.expect(prompt, timeout=10) + i = self.conn.expect(cast(Pattern[str], prompt), timeout=10) # Check if we are already at a prompt, or if we need to log in. 
if i < len(prompt) - 1: self.conn.sendline("{}".format(self.username)) password_prompt = login_password_prompt + [r'# ', self.conn.UNIQUE_PROMPT] - j = self.conn.expect(password_prompt, timeout=self.default_timeout) + j = self.conn.expect(cast(Pattern[str], password_prompt), timeout=self.default_timeout) if j < len(password_prompt) - 2: self.conn.sendline("{}".format(self.password)) self.conn.expect([r'# ', self.conn.UNIQUE_PROMPT], timeout=self.default_timeout) - class AndroidGem5Connection(Gem5Connection): def _wait_for_boot(self): @@ -1637,7 +2114,17 @@ def _wait_for_boot(self): gem5_logger.info("Android booted") -def _give_password(password, command): +def _give_password(password: str, command: str) -> Tuple[str, str]: + """ + Insert a password into an ``sshpass``-based command to allow non-interactive + authentication. + + :param password: The password to embed in the command. + :param command: The original shell command that invokes ``sshpass``. + :returns: A tuple of (modified_command, redacted_command). The first string is + safe to execute, while the second omits the password for logging. + :raises ValueError: If the command cannot be adjusted or if ``sshpass`` is unavailable. 
+ """ sshpass = _SSH_ENV.get_path('sshpass') if sshpass: pass_template = "{} -p {} " @@ -1648,8 +2135,11 @@ def _give_password(password, command): raise HostError('Must have sshpass installed on the host in order to use password-based auth.') -def process_backspaces(text): - chars = [] +def process_backspaces(text: str) -> str: + """ + process backspace in the command + """ + chars: List[str] = [] for c in text: if c == chr(8) and chars: # backspace chars.pop() diff --git a/devlib/utils/types.py b/devlib/utils/types.py index d7c8864b0..55f0fe5a0 100644 --- a/devlib/utils/types.py +++ b/devlib/utils/types.py @@ -1,4 +1,4 @@ -# Copyright 2014-2018 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,9 +30,9 @@ import sys from functools import total_ordering -from past.builtins import basestring from devlib.utils.misc import isiterable, to_identifier, ranges_to_list, list_to_mask +from typing import List, Union def identifier(text): @@ -49,7 +49,7 @@ def boolean(value): """ false_strings = ['', '0', 'n', 'no', 'off'] - if isinstance(value, basestring): + if isinstance(value, str): value = value.lower() if value in false_strings or 'false'.startswith(value): return False @@ -58,7 +58,7 @@ def boolean(value): def integer(value): """Handles conversions for string respresentations of binary, octal and hex.""" - if isinstance(value, basestring): + if isinstance(value, str): return int(value, 0) else: return int(value) @@ -74,7 +74,7 @@ def numeric(value): if isinstance(value, int): return value - if isinstance(value, basestring): + if isinstance(value, str): value = value.strip() if value.endswith('%'): try: @@ -102,17 +102,17 @@ class caseless_string(str): """ def __eq__(self, other): - if isinstance(other, basestring): + if isinstance(other, str): other = other.lower() return self.lower() == other def __ne__(self, other): - if 
isinstance(other, basestring): + if isinstance(other, str): other = other.lower() return self.lower() != other def __lt__(self, other): - if isinstance(other, basestring): + if isinstance(other, str): other = other.lower() return self.lower() < other @@ -123,11 +123,14 @@ def format(self, *args, **kwargs): return caseless_string(super(caseless_string, self).format(*args, **kwargs)) -def bitmask(value): - if isinstance(value, basestring): +def bitmask(value: Union[int, List[int], str]) -> int: + if isinstance(value, str): value = ranges_to_list(value) if isiterable(value): - value = list_to_mask(value) + if isinstance(value, list): + value = list_to_mask(value) + else: + raise TypeError("Expected a list of integers") if not isinstance(value, int): raise ValueError(value) return value diff --git a/devlib/utils/uboot.py b/devlib/utils/uboot.py index 1e0169770..948cfeb98 100644 --- a/devlib/utils/uboot.py +++ b/devlib/utils/uboot.py @@ -15,12 +15,11 @@ # import re import time -import logging from devlib.utils.serial_port import TIMEOUT +from devlib.utils.misc import get_logger - -logger = logging.getLogger('U-Boot') +logger = get_logger('U-Boot') class UbootMenu(object): diff --git a/devlib/utils/uefi.py b/devlib/utils/uefi.py index f56991672..b1f5b9c0d 100644 --- a/devlib/utils/uefi.py +++ b/devlib/utils/uefi.py @@ -16,16 +16,15 @@ import re import time -import logging -from copy import copy -from past.builtins import basestring +from copy import copy from devlib.utils.serial_port import write_characters, TIMEOUT from devlib.utils.types import boolean +from devlib.utils.misc import get_logger -logger = logging.getLogger('UEFI') +logger = get_logger('UEFI') class UefiConfig(object): @@ -134,7 +133,7 @@ def select(self, option, timeout=default_timeout): long-running operation. 
""" - if isinstance(option, basestring): + if isinstance(option, str): option = self.get_option_index(option, timeout) self.enter(option) diff --git a/devlib/utils/version.py b/devlib/utils/version.py index 2409b6783..b299b7385 100644 --- a/devlib/utils/version.py +++ b/devlib/utils/version.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,16 +15,21 @@ import os import sys -from collections import namedtuple from subprocess import Popen, PIPE +from typing import NamedTuple, Optional -VersionTuple = namedtuple('Version', ['major', 'minor', 'revision', 'dev']) +class Version(NamedTuple): + major: int + minor: int + revision: int + dev: str -version = VersionTuple(1, 4, 0, 'dev3') +version = Version(1, 4, 0, 'dev3') -def get_devlib_version(): + +def get_devlib_version() -> str: version_string = '{}.{}.{}'.format( version.major, version.minor, version.revision) if version.dev: @@ -32,7 +37,7 @@ def get_devlib_version(): return version_string -def get_commit(): +def get_commit() -> Optional[str]: try: p = Popen(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__), stdout=PIPE, stderr=PIPE) diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..4077efd50 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,6 @@ +[mypy] +ignore_missing_imports = True +python_version = 3.10 + +[mypy-numpy.*] +ignore_errors = True \ No newline at end of file diff --git a/py.typed b/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/setup.py b/setup.py index cba25a26b..9686c6233 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright 2013-2015 ARM Limited +# Copyright 2013-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -97,17 +97,18 @@ def _load_path(filepath): 'python-dateutil', # converting between UTC and local time. 'pexpect>=3.3', # Send/recieve to/from device 'pyserial', # Serial port interface - 'paramiko', # SSH connection - 'scp', # SSH connection file transfers + 'paramiko', # SSH connection + 'scp', # SSH connection file transfers 'wrapt', # Basic for construction of decorator functions 'numpy', 'pandas', 'pytest', - 'lxml', # More robust xml parsing - 'nest_asyncio', # Allows running nested asyncio loops - 'greenlet', # Allows running nested asyncio loops - 'future', # for the "past" Python package - 'ruamel.yaml >= 0.15.72', # YAML formatted config parsing + 'lxml', # More robust xml parsing + 'nest_asyncio', # Allows running nested asyncio loops + 'greenlet', # Allows running nested asyncio loops + 'future', # for the "past" Python package + 'ruamel.yaml >= 0.15.72', # YAML formatted config parsing + 'typing_extensions' ], extras_require={ 'daq': ['daqpower>=2'], @@ -115,7 +116,7 @@ def _load_path(filepath): 'monsoon': ['python-gflags'], 'acme': ['pandas', 'numpy'], 'dev': [ - 'uvloop', # Test async features under uvloop + 'uvloop', # Test async features under uvloop ] }, # https://pypi.python.org/pypi?%3Aaction=list_classifiers @@ -142,7 +143,6 @@ def initialize_options(self): orig_sdist.initialize_options(self) self.strip_commit = False - def run(self): if self.strip_commit: self.distribution.get_version = lambda : __version__.split('+')[0] diff --git a/tests/test_target.py b/tests/test_target.py index 2d811321f..6665e7343 100644 --- a/tests/test_target.py +++ b/tests/test_target.py @@ -1,5 +1,5 @@ # -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -23,27 +23,26 @@ $ python -m pytest --log-cli-level DEBUG test_target.py """ -import logging import os +from typing import Optional + import pytest from devlib import AndroidTarget, ChromeOsTarget, LinuxTarget, LocalLinuxTarget from devlib._target_runner import NOPTargetRunner, QEMUTargetRunner from devlib.utils.android import AdbConnection -from devlib.utils.misc import load_struct_from_yaml +from devlib.utils.misc import load_struct_from_yaml, get_logger -logger = logging.getLogger('test_target') +logger = get_logger('test_target') -def get_class_object(name): +def get_class_object(name: str) -> Optional[object]: """ Get associated class object from string formatted class name :param name: Class name - :type name: str :return: Class object - :rtype: object or None """ if globals().get(name) is None: return None