Skip to content

Commit d35a5ce

Browse files
Generic cleanup + add py38 annotations (#911)
1 parent 509a0c5 commit d35a5ce

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+515
-397
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ repos:
1616
hooks:
1717
- id: pycln
1818
- repo: https://github.com/charliermarsh/ruff-pre-commit
19-
rev: "v0.1.11"
19+
rev: "v0.1.14"
2020
hooks:
2121
- id: ruff-format
2222
- id: ruff

scripts/smoke.bash

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ pre-commit run ruff --all-files || true
66
pre-commit run ruff-format --all-files || true
77

88
tox --parallel -c tox.ini \
9-
-e py3check
9+
-e py3mypy
1010

1111
tox --parallel -c tox.ini \
12-
-e py3mypy
12+
-e py3check
1313

1414
tox --parallel -c tox.ini \
1515
-e py3

tavern/_core/dict_util.py

Lines changed: 63 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import os
55
import re
66
import string
7-
from typing import Any, Dict, List, Mapping, Union
7+
import typing
8+
from collections.abc import Collection
9+
from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union
810

911
import box
1012
import jmespath
@@ -22,10 +24,10 @@
2224
from .formatted_str import FormattedString
2325
from .strict_util import StrictSetting, StrictSettingKinds, extract_strict_setting
2426

25-
logger = logging.getLogger(__name__)
27+
logger: logging.Logger = logging.getLogger(__name__)
2628

2729

28-
def _check_and_format_values(to_format, box_vars: Mapping[str, Any]) -> str:
30+
def _check_and_format_values(to_format: str, box_vars: Mapping[str, Any]) -> str:
2931
formatter = string.Formatter()
3032
would_format = formatter.parse(to_format)
3133

@@ -55,7 +57,7 @@ def _check_and_format_values(to_format, box_vars: Mapping[str, Any]) -> str:
5557
return to_format.format(**box_vars)
5658

5759

58-
def _attempt_find_include(to_format: str, box_vars: box.Box):
60+
def _attempt_find_include(to_format: str, box_vars: box.Box) -> Optional[str]:
5961
formatter = string.Formatter()
6062
would_format = list(formatter.parse(to_format))
6163

@@ -89,32 +91,39 @@ def _attempt_find_include(to_format: str, box_vars: box.Box):
8991

9092
would_replace = formatter.get_field(field_name, [], box_vars)[0]
9193

92-
return formatter.convert_field(would_replace, conversion) # type: ignore
94+
if conversion is None:
95+
return would_replace
96+
97+
return formatter.convert_field(would_replace, conversion)
98+
99+
100+
T = typing.TypeVar("T", str, Dict, List, Tuple)
93101

94102

95103
def format_keys(
96-
val,
97-
variables: Mapping,
104+
val: T,
105+
variables: Union[Mapping, Box],
98106
*,
99107
no_double_format: bool = True,
100108
dangerously_ignore_string_format_errors: bool = False,
101-
):
109+
) -> T:
102110
"""recursively format a dictionary with the given values
103111
104112
Args:
105-
val: Input dictionary to format
113+
val: Input thing to format
106114
variables: Dictionary of keys to format it with
107115
no_double_format: Whether to use the 'inner formatted string' class to avoid double formatting
108116
This is required if passing something via pytest-xdist, such as markers:
109117
https://github.com/taverntesting/tavern/issues/431
110118
dangerously_ignore_string_format_errors: whether to ignore any string formatting errors. This will result
111119
in broken output, only use for debugging purposes.
112120
121+
Raises:
122+
MissingFormatError: if a format variable was not found in variables
123+
113124
Returns:
114125
recursively formatted values
115126
"""
116-
formatted = val
117-
118127
format_keys_ = functools.partial(
119128
format_keys,
120129
dangerously_ignore_string_format_errors=dangerously_ignore_string_format_errors,
@@ -126,15 +135,15 @@ def format_keys(
126135
box_vars = variables
127136

128137
if isinstance(val, dict):
129-
formatted = {}
130-
# formatted = {key: format_keys(val[key], box_vars) for key in val}
131-
for key in val:
132-
formatted[key] = format_keys_(val[key], box_vars)
133-
elif isinstance(val, (list, tuple)):
134-
formatted = [format_keys_(item, box_vars) for item in val] # type: ignore
135-
elif isinstance(formatted, FormattedString):
136-
logger.debug("Already formatted %s, not double-formatting", formatted)
138+
return {key: format_keys_(val[key], box_vars) for key in val}
139+
elif isinstance(val, tuple):
140+
return tuple(format_keys_(item, box_vars) for item in val)
141+
elif isinstance(val, list):
142+
return [format_keys_(item, box_vars) for item in val]
143+
elif isinstance(val, FormattedString):
144+
logger.debug("Already formatted %s, not double-formatting", val)
137145
elif isinstance(val, str):
146+
formatted = val
138147
try:
139148
formatted = _check_and_format_values(val, box_vars)
140149
except exceptions.MissingFormatError:
@@ -143,20 +152,22 @@ def format_keys(
143152

144153
if no_double_format:
145154
formatted = FormattedString(formatted) # type: ignore
155+
156+
return formatted
146157
elif isinstance(val, TypeConvertToken):
147158
logger.debug("Got type convert token '%s'", val)
148159
if isinstance(val, ForceIncludeToken):
149-
formatted = _attempt_find_include(val.value, box_vars)
160+
return _attempt_find_include(val.value, box_vars)
150161
else:
151162
value = format_keys_(val.value, box_vars)
152-
formatted = val.constructor(value)
163+
return val.constructor(value)
153164
else:
154-
logger.debug("Not formatting something of type '%s'", type(formatted))
165+
logger.debug("Not formatting something of type '%s'", type(val))
155166

156-
return formatted
167+
return val
157168

158169

159-
def recurse_access_key(data, query: str):
170+
def recurse_access_key(data: Union[List, Mapping], query: str) -> Any:
160171
"""
161172
Search for something in the given data using the given query.
162173
@@ -168,11 +179,14 @@ def recurse_access_key(data, query: str):
168179
'c'
169180
170181
Args:
171-
data (dict, list): Data to search in
172-
query (str): Query to run
182+
data: Data to search in
183+
query: Query to run
184+
185+
Raises:
186+
JMESError: if there was an error parsing the query
173187
174188
Returns:
175-
object: Whatever was found by the search
189+
Whatever was found by the search
176190
"""
177191

178192
try:
@@ -195,7 +209,9 @@ def recurse_access_key(data, query: str):
195209
return from_jmespath
196210

197211

198-
def _deprecated_recurse_access_key(current_val, keys):
212+
def _deprecated_recurse_access_key(
213+
current_val: Union[List, Mapping], keys: List
214+
) -> Any:
199215
"""Given a list of keys and a dictionary, recursively access the dicionary
200216
using the keys until we find the key its looking for
201217
@@ -209,15 +225,15 @@ def _deprecated_recurse_access_key(current_val, keys):
209225
'c'
210226
211227
Args:
212-
current_val (dict): current dictionary we have recursed into
213-
keys (list): list of str/int of subkeys
228+
current_val: current dictionary we have recursed into
229+
keys: list of str/int of subkeys
214230
215231
Raises:
216232
IndexError: list index not found in data
217233
KeyError: dict key not found in data
218234
219235
Returns:
220-
str or dict: value of subkey in dict
236+
value of subkey in dict
221237
"""
222238
logger.debug("Recursively searching for '%s' in '%s'", keys, current_val)
223239

@@ -266,12 +282,12 @@ def deep_dict_merge(initial_dct: Dict, merge_dct: Mapping) -> dict:
266282
return dct
267283

268284

269-
def check_expected_keys(expected, actual) -> None:
285+
def check_expected_keys(expected: Collection, actual: Collection) -> None:
270286
"""Check that a set of expected keys is a superset of the actual keys
271287
272288
Args:
273-
expected (list, set, dict): keys we expect
274-
actual (list, set, dict): keys we have got from the input
289+
expected: keys we expect
290+
actual: keys we have got from the input
275291
276292
Raises:
277293
UnexpectedKeysError: If not actual <= expected
@@ -289,7 +305,7 @@ def check_expected_keys(expected, actual) -> None:
289305
raise exceptions.UnexpectedKeysError(msg)
290306

291307

292-
def yield_keyvals(block):
308+
def yield_keyvals(block: Union[List, Dict]) -> Iterator[Tuple[List, str, str]]:
293309
"""Return indexes, keys and expected values for matching recursive keys
294310
295311
Given a list or dict, return a 3-tuple of the 'split' key (key split on
@@ -321,10 +337,10 @@ def yield_keyvals(block):
321337
(['2'], '2', 'c')
322338
323339
Args:
324-
block (dict, list): input matches
340+
block: input matches
325341
326342
Yields:
327-
(list, str, str): key split on dots, key, expected value
343+
iterable of (key split on dots, key, expected value)
328344
"""
329345
if isinstance(block, dict):
330346
for joined_key, expected_val in block.items():
@@ -336,9 +352,12 @@ def yield_keyvals(block):
336352
yield [sidx], sidx, val
337353

338354

355+
Checked = typing.TypeVar("Checked", Dict, Collection, str)
356+
357+
339358
def check_keys_match_recursive(
340-
expected_val: Any,
341-
actual_val: Any,
359+
expected_val: Checked,
360+
actual_val: Checked,
342361
keys: List[Union[str, int]],
343362
strict: StrictSettingKinds = True,
344363
) -> None:
@@ -443,8 +462,8 @@ def _format_err(which):
443462
raise exceptions.KeyMismatchError(msg) from e
444463

445464
if isinstance(expected_val, dict):
446-
akeys = set(actual_val.keys())
447465
ekeys = set(expected_val.keys())
466+
akeys = set(actual_val.keys()) # type:ignore
448467

449468
if akeys != ekeys:
450469
extra_actual_keys = akeys - ekeys
@@ -481,7 +500,10 @@ def _format_err(which):
481500
for key in to_recurse:
482501
try:
483502
check_keys_match_recursive(
484-
expected_val[key], actual_val[key], keys + [key], strict
503+
expected_val[key],
504+
actual_val[key], # type:ignore
505+
keys + [key],
506+
strict,
485507
)
486508
except KeyError:
487509
logger.debug(

tavern/_core/extfunctions.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import functools
22
import importlib
33
import logging
4-
from typing import Any, List, Mapping, Optional
4+
from typing import Any, Callable, Iterable, List, Mapping, Optional, Tuple
55

66
from tavern._core import exceptions
77

@@ -16,7 +16,7 @@ def is_ext_function(block: Any) -> bool:
1616
block: Any object
1717
1818
Returns:
19-
bool: If it is an ext function style dict
19+
If it is an ext function style dict
2020
"""
2121
return isinstance(block, dict) and block.get("$ext", None) is not None
2222

@@ -29,17 +29,20 @@ def get_pykwalify_logger(module: Optional[str]) -> logging.Logger:
2929
trying to get the root logger which won't log correctly
3030
3131
Args:
32-
module (string): name of module to get logger for
32+
module: name of module to get logger for
3333
34+
Returns:
35+
logger for given module
3436
"""
3537
return logging.getLogger(module)
3638

3739

3840
def _getlogger() -> logging.Logger:
41+
"""Get logger for this module"""
3942
return get_pykwalify_logger("tavern._core.extfunctions")
4043

4144

42-
def import_ext_function(entrypoint: str):
45+
def import_ext_function(entrypoint: str) -> Callable:
4346
"""Given a function name in the form of a setuptools entry point, try to
4447
dynamically load and return it
4548
@@ -48,7 +51,7 @@ def import_ext_function(entrypoint: str):
4851
module.submodule:function
4952
5053
Returns:
51-
function: function loaded from entrypoint
54+
function loaded from entrypoint
5255
5356
Raises:
5457
InvalidExtFunctionError: If the module or function did not exist
@@ -79,7 +82,7 @@ def import_ext_function(entrypoint: str):
7982
return function
8083

8184

82-
def get_wrapped_response_function(ext: Mapping):
85+
def get_wrapped_response_function(ext: Mapping) -> Callable:
8386
"""Wraps a ext function with arguments given in the test file
8487
8588
This is similar to functools.wrap, but this makes sure that 'response' is
@@ -90,7 +93,7 @@ def get_wrapped_response_function(ext: Mapping):
9093
extra_kwargs to pass
9194
9295
Returns:
93-
function: Wrapped function
96+
Wrapped function
9497
"""
9598

9699
func, args, kwargs = _get_ext_values(ext)
@@ -106,7 +109,7 @@ def inner(response):
106109
return inner
107110

108111

109-
def get_wrapped_create_function(ext: Mapping):
112+
def get_wrapped_create_function(ext: Mapping) -> Callable:
110113
"""Same as get_wrapped_response_function, but don't require a response"""
111114

112115
func, args, kwargs = _get_ext_values(ext)
@@ -122,7 +125,7 @@ def inner():
122125
return inner
123126

124127

125-
def _get_ext_values(ext: Mapping):
128+
def _get_ext_values(ext: Mapping) -> Tuple[Callable, Iterable, Mapping]:
126129
if not isinstance(ext, Mapping):
127130
raise exceptions.InvalidExtFunctionError(
128131
f"ext block should be a dict, but it was a {type(ext)}"

tavern/_core/general.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import logging
22
import os
3-
from typing import List
3+
from typing import List, Union
44

55
from tavern._core.loader import load_single_document_yaml
66

77
from .dict_util import deep_dict_merge
88

9-
logger = logging.getLogger(__name__)
9+
logger: logging.Logger = logging.getLogger(__name__)
1010

1111

12-
def load_global_config(global_cfg_paths: List[os.PathLike]) -> dict:
12+
def load_global_config(global_cfg_paths: List[Union[str, os.PathLike]]) -> dict:
1313
"""Given a list of file paths to global config files, load each of them and
1414
return the joined dictionary.
1515

tavern/_core/jmesutils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def safe_length(var: Sized) -> int:
6161
return -1
6262

6363

64-
def validate_comparison(each_comparison):
64+
def validate_comparison(each_comparison: Dict[Any, Any]):
6565
if extra := set(each_comparison.keys()) - {"jmespath", "operator", "expected"}:
6666
raise exceptions.BadSchemaError(
6767
"Invalid keys given to JMES validation function (got extra keys: {})".format(

0 commit comments

Comments
 (0)