4
4
import os
5
5
import re
6
6
import string
7
- from typing import Any , Dict , List , Mapping , Union
7
+ import typing
8
+ from collections .abc import Collection
9
+ from typing import Any , Dict , Iterator , List , Mapping , Optional , Tuple , Union
8
10
9
11
import box
10
12
import jmespath
22
24
from .formatted_str import FormattedString
23
25
from .strict_util import StrictSetting , StrictSettingKinds , extract_strict_setting
24
26
25
- logger = logging .getLogger (__name__ )
27
+ logger : logging . Logger = logging .getLogger (__name__ )
26
28
27
29
28
- def _check_and_format_values (to_format , box_vars : Mapping [str , Any ]) -> str :
30
+ def _check_and_format_values (to_format : str , box_vars : Mapping [str , Any ]) -> str :
29
31
formatter = string .Formatter ()
30
32
would_format = formatter .parse (to_format )
31
33
@@ -55,7 +57,7 @@ def _check_and_format_values(to_format, box_vars: Mapping[str, Any]) -> str:
55
57
return to_format .format (** box_vars )
56
58
57
59
58
- def _attempt_find_include (to_format : str , box_vars : box .Box ):
60
+ def _attempt_find_include (to_format : str , box_vars : box .Box ) -> Optional [ str ] :
59
61
formatter = string .Formatter ()
60
62
would_format = list (formatter .parse (to_format ))
61
63
@@ -89,32 +91,39 @@ def _attempt_find_include(to_format: str, box_vars: box.Box):
89
91
90
92
would_replace = formatter .get_field (field_name , [], box_vars )[0 ]
91
93
92
- return formatter .convert_field (would_replace , conversion ) # type: ignore
94
+ if conversion is None :
95
+ return would_replace
96
+
97
+ return formatter .convert_field (would_replace , conversion )
98
+
99
+
100
+ T = typing .TypeVar ("T" , str , Dict , List , Tuple )
93
101
94
102
95
103
def format_keys (
96
- val ,
97
- variables : Mapping ,
104
+ val : T ,
105
+ variables : Union [ Mapping , Box ] ,
98
106
* ,
99
107
no_double_format : bool = True ,
100
108
dangerously_ignore_string_format_errors : bool = False ,
101
- ):
109
+ ) -> T :
102
110
"""recursively format a dictionary with the given values
103
111
104
112
Args:
105
- val: Input dictionary to format
113
+ val: Input thing to format
106
114
variables: Dictionary of keys to format it with
107
115
no_double_format: Whether to use the 'inner formatted string' class to avoid double formatting
108
116
This is required if passing something via pytest-xdist, such as markers:
109
117
https://github.com/taverntesting/tavern/issues/431
110
118
dangerously_ignore_string_format_errors: whether to ignore any string formatting errors. This will result
111
119
in broken output, only use for debugging purposes.
112
120
121
+ Raises:
122
+ MissingFormatError: if a format variable was not found in variables
123
+
113
124
Returns:
114
125
recursively formatted values
115
126
"""
116
- formatted = val
117
-
118
127
format_keys_ = functools .partial (
119
128
format_keys ,
120
129
dangerously_ignore_string_format_errors = dangerously_ignore_string_format_errors ,
@@ -126,15 +135,15 @@ def format_keys(
126
135
box_vars = variables
127
136
128
137
if isinstance (val , dict ):
129
- formatted = {}
130
- # formatted = {key: format_keys(val[key], box_vars) for key in val}
131
- for key in val :
132
- formatted [key ] = format_keys_ (val [key ], box_vars )
133
- elif isinstance (val , (list , tuple )):
134
- formatted = [format_keys_ (item , box_vars ) for item in val ] # type: ignore
135
- elif isinstance (formatted , FormattedString ):
136
- logger .debug ("Already formatted %s, not double-formatting" , formatted )
138
+ return {key : format_keys_ (val [key ], box_vars ) for key in val }
139
+ elif isinstance (val , tuple ):
140
+ return tuple (format_keys_ (item , box_vars ) for item in val )
141
+ elif isinstance (val , list ):
142
+ return [format_keys_ (item , box_vars ) for item in val ]
143
+ elif isinstance (val , FormattedString ):
144
+ logger .debug ("Already formatted %s, not double-formatting" , val )
137
145
elif isinstance (val , str ):
146
+ formatted = val
138
147
try :
139
148
formatted = _check_and_format_values (val , box_vars )
140
149
except exceptions .MissingFormatError :
@@ -143,20 +152,22 @@ def format_keys(
143
152
144
153
if no_double_format :
145
154
formatted = FormattedString (formatted ) # type: ignore
155
+
156
+ return formatted
146
157
elif isinstance (val , TypeConvertToken ):
147
158
logger .debug ("Got type convert token '%s'" , val )
148
159
if isinstance (val , ForceIncludeToken ):
149
- formatted = _attempt_find_include (val .value , box_vars )
160
+ return _attempt_find_include (val .value , box_vars )
150
161
else :
151
162
value = format_keys_ (val .value , box_vars )
152
- formatted = val .constructor (value )
163
+ return val .constructor (value )
153
164
else :
154
- logger .debug ("Not formatting something of type '%s'" , type (formatted ))
165
+ logger .debug ("Not formatting something of type '%s'" , type (val ))
155
166
156
- return formatted
167
+ return val
157
168
158
169
159
- def recurse_access_key (data , query : str ):
170
+ def recurse_access_key (data : Union [ List , Mapping ], query : str ) -> Any :
160
171
"""
161
172
Search for something in the given data using the given query.
162
173
@@ -168,11 +179,14 @@ def recurse_access_key(data, query: str):
168
179
'c'
169
180
170
181
Args:
171
- data (dict, list): Data to search in
172
- query (str): Query to run
182
+ data: Data to search in
183
+ query: Query to run
184
+
185
+ Raises:
186
+ JMESError: if there was an error parsing the query
173
187
174
188
Returns:
175
- object: Whatever was found by the search
189
+ Whatever was found by the search
176
190
"""
177
191
178
192
try :
@@ -195,7 +209,9 @@ def recurse_access_key(data, query: str):
195
209
return from_jmespath
196
210
197
211
198
- def _deprecated_recurse_access_key (current_val , keys ):
212
+ def _deprecated_recurse_access_key (
213
+ current_val : Union [List , Mapping ], keys : List
214
+ ) -> Any :
199
215
"""Given a list of keys and a dictionary, recursively access the dicionary
200
216
using the keys until we find the key its looking for
201
217
@@ -209,15 +225,15 @@ def _deprecated_recurse_access_key(current_val, keys):
209
225
'c'
210
226
211
227
Args:
212
- current_val (dict) : current dictionary we have recursed into
213
- keys (list) : list of str/int of subkeys
228
+ current_val: current dictionary we have recursed into
229
+ keys: list of str/int of subkeys
214
230
215
231
Raises:
216
232
IndexError: list index not found in data
217
233
KeyError: dict key not found in data
218
234
219
235
Returns:
220
- str or dict: value of subkey in dict
236
+ value of subkey in dict
221
237
"""
222
238
logger .debug ("Recursively searching for '%s' in '%s'" , keys , current_val )
223
239
@@ -266,12 +282,12 @@ def deep_dict_merge(initial_dct: Dict, merge_dct: Mapping) -> dict:
266
282
return dct
267
283
268
284
269
- def check_expected_keys (expected , actual ) -> None :
285
+ def check_expected_keys (expected : Collection , actual : Collection ) -> None :
270
286
"""Check that a set of expected keys is a superset of the actual keys
271
287
272
288
Args:
273
- expected (list, set, dict) : keys we expect
274
- actual (list, set, dict) : keys we have got from the input
289
+ expected: keys we expect
290
+ actual: keys we have got from the input
275
291
276
292
Raises:
277
293
UnexpectedKeysError: If not actual <= expected
@@ -289,7 +305,7 @@ def check_expected_keys(expected, actual) -> None:
289
305
raise exceptions .UnexpectedKeysError (msg )
290
306
291
307
292
- def yield_keyvals (block ) :
308
+ def yield_keyvals (block : Union [ List , Dict ]) -> Iterator [ Tuple [ List , str , str ]] :
293
309
"""Return indexes, keys and expected values for matching recursive keys
294
310
295
311
Given a list or dict, return a 3-tuple of the 'split' key (key split on
@@ -321,10 +337,10 @@ def yield_keyvals(block):
321
337
(['2'], '2', 'c')
322
338
323
339
Args:
324
- block (dict, list) : input matches
340
+ block: input matches
325
341
326
342
Yields:
327
- (list, str, str): key split on dots, key, expected value
343
+ iterable of ( key split on dots, key, expected value)
328
344
"""
329
345
if isinstance (block , dict ):
330
346
for joined_key , expected_val in block .items ():
@@ -336,9 +352,12 @@ def yield_keyvals(block):
336
352
yield [sidx ], sidx , val
337
353
338
354
355
+ Checked = typing .TypeVar ("Checked" , Dict , Collection , str )
356
+
357
+
339
358
def check_keys_match_recursive (
340
- expected_val : Any ,
341
- actual_val : Any ,
359
+ expected_val : Checked ,
360
+ actual_val : Checked ,
342
361
keys : List [Union [str , int ]],
343
362
strict : StrictSettingKinds = True ,
344
363
) -> None :
@@ -443,8 +462,8 @@ def _format_err(which):
443
462
raise exceptions .KeyMismatchError (msg ) from e
444
463
445
464
if isinstance (expected_val , dict ):
446
- akeys = set (actual_val .keys ())
447
465
ekeys = set (expected_val .keys ())
466
+ akeys = set (actual_val .keys ()) # type:ignore
448
467
449
468
if akeys != ekeys :
450
469
extra_actual_keys = akeys - ekeys
@@ -481,7 +500,10 @@ def _format_err(which):
481
500
for key in to_recurse :
482
501
try :
483
502
check_keys_match_recursive (
484
- expected_val [key ], actual_val [key ], keys + [key ], strict
503
+ expected_val [key ],
504
+ actual_val [key ], # type:ignore
505
+ keys + [key ],
506
+ strict ,
485
507
)
486
508
except KeyError :
487
509
logger .debug (
0 commit comments