Skip to content

Commit

Permalink
Generic cleanup + add py38 annotations (#911)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelboulton authored Jan 27, 2024
1 parent 509a0c5 commit d35a5ce
Show file tree
Hide file tree
Showing 43 changed files with 515 additions and 397 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
hooks:
- id: pycln
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.1.11"
rev: "v0.1.14"
hooks:
- id: ruff-format
- id: ruff
Expand Down
4 changes: 2 additions & 2 deletions scripts/smoke.bash
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ pre-commit run ruff --all-files || true
pre-commit run ruff-format --all-files || true

tox --parallel -c tox.ini \
-e py3check
-e py3mypy

tox --parallel -c tox.ini \
-e py3mypy
-e py3check

tox --parallel -c tox.ini \
-e py3
Expand Down
104 changes: 63 additions & 41 deletions tavern/_core/dict_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import os
import re
import string
from typing import Any, Dict, List, Mapping, Union
import typing
from collections.abc import Collection
from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union

import box
import jmespath
Expand All @@ -22,10 +24,10 @@
from .formatted_str import FormattedString
from .strict_util import StrictSetting, StrictSettingKinds, extract_strict_setting

logger = logging.getLogger(__name__)
logger: logging.Logger = logging.getLogger(__name__)


def _check_and_format_values(to_format, box_vars: Mapping[str, Any]) -> str:
def _check_and_format_values(to_format: str, box_vars: Mapping[str, Any]) -> str:
formatter = string.Formatter()
would_format = formatter.parse(to_format)

Expand Down Expand Up @@ -55,7 +57,7 @@ def _check_and_format_values(to_format, box_vars: Mapping[str, Any]) -> str:
return to_format.format(**box_vars)


def _attempt_find_include(to_format: str, box_vars: box.Box):
def _attempt_find_include(to_format: str, box_vars: box.Box) -> Optional[str]:
formatter = string.Formatter()
would_format = list(formatter.parse(to_format))

Expand Down Expand Up @@ -89,32 +91,39 @@ def _attempt_find_include(to_format: str, box_vars: box.Box):

would_replace = formatter.get_field(field_name, [], box_vars)[0]

return formatter.convert_field(would_replace, conversion) # type: ignore
if conversion is None:
return would_replace

return formatter.convert_field(would_replace, conversion)


T = typing.TypeVar("T", str, Dict, List, Tuple)


def format_keys(
val,
variables: Mapping,
val: T,
variables: Union[Mapping, Box],
*,
no_double_format: bool = True,
dangerously_ignore_string_format_errors: bool = False,
):
) -> T:
"""recursively format a dictionary with the given values
Args:
val: Input dictionary to format
val: Input thing to format
variables: Dictionary of keys to format it with
no_double_format: Whether to use the 'inner formatted string' class to avoid double formatting
This is required if passing something via pytest-xdist, such as markers:
https://github.com/taverntesting/tavern/issues/431
dangerously_ignore_string_format_errors: whether to ignore any string formatting errors. This will result
in broken output, only use for debugging purposes.
Raises:
MissingFormatError: if a format variable was not found in variables
Returns:
recursively formatted values
"""
formatted = val

format_keys_ = functools.partial(
format_keys,
dangerously_ignore_string_format_errors=dangerously_ignore_string_format_errors,
Expand All @@ -126,15 +135,15 @@ def format_keys(
box_vars = variables

if isinstance(val, dict):
formatted = {}
# formatted = {key: format_keys(val[key], box_vars) for key in val}
for key in val:
formatted[key] = format_keys_(val[key], box_vars)
elif isinstance(val, (list, tuple)):
formatted = [format_keys_(item, box_vars) for item in val] # type: ignore
elif isinstance(formatted, FormattedString):
logger.debug("Already formatted %s, not double-formatting", formatted)
return {key: format_keys_(val[key], box_vars) for key in val}
elif isinstance(val, tuple):
return tuple(format_keys_(item, box_vars) for item in val)
elif isinstance(val, list):
return [format_keys_(item, box_vars) for item in val]
elif isinstance(val, FormattedString):
logger.debug("Already formatted %s, not double-formatting", val)
elif isinstance(val, str):
formatted = val
try:
formatted = _check_and_format_values(val, box_vars)
except exceptions.MissingFormatError:
Expand All @@ -143,20 +152,22 @@ def format_keys(

if no_double_format:
formatted = FormattedString(formatted) # type: ignore

return formatted
elif isinstance(val, TypeConvertToken):
logger.debug("Got type convert token '%s'", val)
if isinstance(val, ForceIncludeToken):
formatted = _attempt_find_include(val.value, box_vars)
return _attempt_find_include(val.value, box_vars)
else:
value = format_keys_(val.value, box_vars)
formatted = val.constructor(value)
return val.constructor(value)
else:
logger.debug("Not formatting something of type '%s'", type(formatted))
logger.debug("Not formatting something of type '%s'", type(val))

return formatted
return val


def recurse_access_key(data, query: str):
def recurse_access_key(data: Union[List, Mapping], query: str) -> Any:
"""
Search for something in the given data using the given query.
Expand All @@ -168,11 +179,14 @@ def recurse_access_key(data, query: str):
'c'
Args:
data (dict, list): Data to search in
query (str): Query to run
data: Data to search in
query: Query to run
Raises:
JMESError: if there was an error parsing the query
Returns:
object: Whatever was found by the search
Whatever was found by the search
"""

try:
Expand All @@ -195,7 +209,9 @@ def recurse_access_key(data, query: str):
return from_jmespath


def _deprecated_recurse_access_key(current_val, keys):
def _deprecated_recurse_access_key(
current_val: Union[List, Mapping], keys: List
) -> Any:
"""Given a list of keys and a dictionary, recursively access the dicionary
using the keys until we find the key its looking for
Expand All @@ -209,15 +225,15 @@ def _deprecated_recurse_access_key(current_val, keys):
'c'
Args:
current_val (dict): current dictionary we have recursed into
keys (list): list of str/int of subkeys
current_val: current dictionary we have recursed into
keys: list of str/int of subkeys
Raises:
IndexError: list index not found in data
KeyError: dict key not found in data
Returns:
str or dict: value of subkey in dict
value of subkey in dict
"""
logger.debug("Recursively searching for '%s' in '%s'", keys, current_val)

Expand Down Expand Up @@ -266,12 +282,12 @@ def deep_dict_merge(initial_dct: Dict, merge_dct: Mapping) -> dict:
return dct


def check_expected_keys(expected, actual) -> None:
def check_expected_keys(expected: Collection, actual: Collection) -> None:
"""Check that a set of expected keys is a superset of the actual keys
Args:
expected (list, set, dict): keys we expect
actual (list, set, dict): keys we have got from the input
expected: keys we expect
actual: keys we have got from the input
Raises:
UnexpectedKeysError: If not actual <= expected
Expand All @@ -289,7 +305,7 @@ def check_expected_keys(expected, actual) -> None:
raise exceptions.UnexpectedKeysError(msg)


def yield_keyvals(block):
def yield_keyvals(block: Union[List, Dict]) -> Iterator[Tuple[List, str, str]]:
"""Return indexes, keys and expected values for matching recursive keys
Given a list or dict, return a 3-tuple of the 'split' key (key split on
Expand Down Expand Up @@ -321,10 +337,10 @@ def yield_keyvals(block):
(['2'], '2', 'c')
Args:
block (dict, list): input matches
block: input matches
Yields:
(list, str, str): key split on dots, key, expected value
iterable of (key split on dots, key, expected value)
"""
if isinstance(block, dict):
for joined_key, expected_val in block.items():
Expand All @@ -336,9 +352,12 @@ def yield_keyvals(block):
yield [sidx], sidx, val


Checked = typing.TypeVar("Checked", Dict, Collection, str)


def check_keys_match_recursive(
expected_val: Any,
actual_val: Any,
expected_val: Checked,
actual_val: Checked,
keys: List[Union[str, int]],
strict: StrictSettingKinds = True,
) -> None:
Expand Down Expand Up @@ -443,8 +462,8 @@ def _format_err(which):
raise exceptions.KeyMismatchError(msg) from e

if isinstance(expected_val, dict):
akeys = set(actual_val.keys())
ekeys = set(expected_val.keys())
akeys = set(actual_val.keys()) # type:ignore

if akeys != ekeys:
extra_actual_keys = akeys - ekeys
Expand Down Expand Up @@ -481,7 +500,10 @@ def _format_err(which):
for key in to_recurse:
try:
check_keys_match_recursive(
expected_val[key], actual_val[key], keys + [key], strict
expected_val[key],
actual_val[key], # type:ignore
keys + [key],
strict,
)
except KeyError:
logger.debug(
Expand Down
21 changes: 12 additions & 9 deletions tavern/_core/extfunctions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import importlib
import logging
from typing import Any, List, Mapping, Optional
from typing import Any, Callable, Iterable, List, Mapping, Optional, Tuple

from tavern._core import exceptions

Expand All @@ -16,7 +16,7 @@ def is_ext_function(block: Any) -> bool:
block: Any object
Returns:
bool: If it is an ext function style dict
If it is an ext function style dict
"""
return isinstance(block, dict) and block.get("$ext", None) is not None

Expand All @@ -29,17 +29,20 @@ def get_pykwalify_logger(module: Optional[str]) -> logging.Logger:
trying to get the root logger which won't log correctly
Args:
module (string): name of module to get logger for
module: name of module to get logger for
Returns:
logger for given module
"""
return logging.getLogger(module)


def _getlogger() -> logging.Logger:
"""Get logger for this module"""
return get_pykwalify_logger("tavern._core.extfunctions")


def import_ext_function(entrypoint: str):
def import_ext_function(entrypoint: str) -> Callable:
"""Given a function name in the form of a setuptools entry point, try to
dynamically load and return it
Expand All @@ -48,7 +51,7 @@ def import_ext_function(entrypoint: str):
module.submodule:function
Returns:
function: function loaded from entrypoint
function loaded from entrypoint
Raises:
InvalidExtFunctionError: If the module or function did not exist
Expand Down Expand Up @@ -79,7 +82,7 @@ def import_ext_function(entrypoint: str):
return function


def get_wrapped_response_function(ext: Mapping):
def get_wrapped_response_function(ext: Mapping) -> Callable:
"""Wraps a ext function with arguments given in the test file
This is similar to functools.wrap, but this makes sure that 'response' is
Expand All @@ -90,7 +93,7 @@ def get_wrapped_response_function(ext: Mapping):
extra_kwargs to pass
Returns:
function: Wrapped function
Wrapped function
"""

func, args, kwargs = _get_ext_values(ext)
Expand All @@ -106,7 +109,7 @@ def inner(response):
return inner


def get_wrapped_create_function(ext: Mapping):
def get_wrapped_create_function(ext: Mapping) -> Callable:
"""Same as get_wrapped_response_function, but don't require a response"""

func, args, kwargs = _get_ext_values(ext)
Expand All @@ -122,7 +125,7 @@ def inner():
return inner


def _get_ext_values(ext: Mapping):
def _get_ext_values(ext: Mapping) -> Tuple[Callable, Iterable, Mapping]:
if not isinstance(ext, Mapping):
raise exceptions.InvalidExtFunctionError(
f"ext block should be a dict, but it was a {type(ext)}"
Expand Down
6 changes: 3 additions & 3 deletions tavern/_core/general.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import logging
import os
from typing import List
from typing import List, Union

from tavern._core.loader import load_single_document_yaml

from .dict_util import deep_dict_merge

logger = logging.getLogger(__name__)
logger: logging.Logger = logging.getLogger(__name__)


def load_global_config(global_cfg_paths: List[os.PathLike]) -> dict:
def load_global_config(global_cfg_paths: List[Union[str, os.PathLike]]) -> dict:
"""Given a list of file paths to global config files, load each of them and
return the joined dictionary.
Expand Down
2 changes: 1 addition & 1 deletion tavern/_core/jmesutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def safe_length(var: Sized) -> int:
return -1


def validate_comparison(each_comparison):
def validate_comparison(each_comparison: Dict[Any, Any]):
if extra := set(each_comparison.keys()) - {"jmespath", "operator", "expected"}:
raise exceptions.BadSchemaError(
"Invalid keys given to JMES validation function (got extra keys: {})".format(
Expand Down
Loading

0 comments on commit d35a5ce

Please sign in to comment.