1
1
"""Deals with nodes which are dependencies or products of a task."""
2
2
import functools
3
3
import inspect
4
+ import itertools
4
5
import pathlib
5
6
from abc import ABCMeta
6
7
from abc import abstractmethod
13
14
from _pytask .exceptions import NodeNotCollectedError
14
15
from _pytask .exceptions import NodeNotFoundError
15
16
from _pytask .mark import get_marks_from_obj
16
- from _pytask .shared import to_list
17
+ from _pytask .shared import find_duplicates
17
18
18
19
19
20
def depends_on (objects : Union [Any , Iterable [Any ]]) -> Union [Any , Iterable [Any ]]:
@@ -68,22 +69,24 @@ class PythonFunctionTask(MetaTask):
68
69
"""pathlib.Path: Path to the file where the task was defined."""
69
70
function = attr .ib (type = callable )
70
71
"""callable: The task function."""
71
- depends_on = attr .ib (converter = to_list )
72
+ depends_on = attr .ib (factory = dict )
72
73
"""Optional[List[MetaNode]]: A list of dependencies of task."""
73
- produces = attr .ib (converter = to_list )
74
+ produces = attr .ib (factory = dict )
74
75
"""List[MetaNode]: A list of products of task."""
75
- markers = attr .ib ()
76
+ markers = attr .ib (factory = list )
76
77
"""Optional[List[Mark]]: A list of markers attached to the task function."""
77
78
_report_sections = attr .ib (factory = list )
78
79
79
80
@classmethod
80
81
def from_path_name_function_session (cls , path , name , function , session ):
81
82
"""Create a task from a path, name, function, and session."""
82
83
objects = _extract_nodes_from_function_markers (function , depends_on )
83
- dependencies = _collect_nodes (session , path , name , objects )
84
+ nodes = _convert_objects_to_node_dictionary (objects , "depends_on" )
85
+ dependencies = _collect_nodes (session , path , name , nodes )
84
86
85
87
objects = _extract_nodes_from_function_markers (function , produces )
86
- products = _collect_nodes (session , path , name , objects )
88
+ nodes = _convert_objects_to_node_dictionary (objects , "produces" )
89
+ products = _collect_nodes (session , path , name , nodes )
87
90
88
91
markers = [
89
92
marker
@@ -118,8 +121,10 @@ def _get_kwargs_from_task_for_function(self):
118
121
attribute = getattr (self , name )
119
122
kwargs [name ] = (
120
123
attribute [0 ].value
121
- if len (attribute ) == 1
122
- else [node .value for node in attribute ]
124
+ if len (attribute ) == 1 and 0 in attribute
125
+ else {
126
+ node_name : node .value for node_name , node in attribute .items ()
127
+ }
123
128
)
124
129
125
130
return kwargs
@@ -169,8 +174,9 @@ def state(self):
169
174
170
175
def _collect_nodes (session , path , name , nodes ):
171
176
"""Collect nodes for a task."""
172
- collect_nodes = []
173
- for node in nodes :
177
+ collected_nodes = {}
178
+
179
+ for node_name , node in nodes .items ():
174
180
collected_node = session .hook .pytask_collect_node (
175
181
session = session , path = path , node = node
176
182
)
@@ -180,9 +186,9 @@ def _collect_nodes(session, path, name, nodes):
180
186
f"'{ name } ' in '{ path } '."
181
187
)
182
188
else :
183
- collect_nodes . append ( collected_node )
189
+ collected_nodes [ node_name ] = collected_node
184
190
185
- return collect_nodes
191
+ return collected_nodes
186
192
187
193
188
194
def _extract_nodes_from_function_markers (function , parser ):
@@ -195,4 +201,82 @@ def _extract_nodes_from_function_markers(function, parser):
195
201
"""
196
202
marker_name = parser .__name__
197
203
for marker in get_marks_from_obj (function , marker_name ):
198
- yield from to_list (parser (* marker .args , ** marker .kwargs ))
204
+ parsed = parser (* marker .args , ** marker .kwargs )
205
+ yield parsed
206
+
207
+
208
+ def _convert_objects_to_node_dictionary (objects , when ):
209
+ list_of_tuples = _convert_objects_to_list_of_tuples (objects )
210
+ _check_that_names_are_not_used_multiple_times (list_of_tuples , when )
211
+ nodes = _convert_nodes_to_dictionary (list_of_tuples )
212
+ return nodes
213
+
214
+
215
+ def _convert_objects_to_list_of_tuples (objects ):
216
+ out = []
217
+ for obj in objects :
218
+ if isinstance (obj , dict ):
219
+ obj = obj .items ()
220
+
221
+ if isinstance (obj , Iterable ) and not isinstance (obj , str ):
222
+ for x in obj :
223
+ if isinstance (x , Iterable ) and not isinstance (x , str ):
224
+ tuple_x = tuple (x )
225
+ if len (tuple_x ) in [1 , 2 ]:
226
+ out .append (tuple_x )
227
+ else :
228
+ raise ValueError ("ERROR" )
229
+ else :
230
+ out .append ((x ,))
231
+ else :
232
+ out .append ((obj ,))
233
+
234
+ return out
235
+
236
+
237
+ def _check_that_names_are_not_used_multiple_times (list_of_tuples , when ):
238
+ """Check that names of nodes are not assigned multiple times.
239
+
240
+ Tuples in the list have either one or two elements. The first element in the two
241
+ element tuples is the name and cannot occur twice.
242
+
243
+ Examples
244
+ --------
245
+ >>> _check_that_names_are_not_used_multiple_times(
246
+ ... [("a",), ("a", 1)], "depends_on"
247
+ ... )
248
+ >>> _check_that_names_are_not_used_multiple_times(
249
+ ... [("a", 0), ("a", 1)], "produces"
250
+ ... )
251
+ Traceback (most recent call last):
252
+ ValueError: '@pytask.mark.produces' has nodes with the same name: {'a'}
253
+
254
+ """
255
+ names = [x [0 ] for x in list_of_tuples if len (x ) == 2 ]
256
+ duplicated = find_duplicates (names )
257
+
258
+ if duplicated :
259
+ raise ValueError (
260
+ f"'@pytask.mark.{ when } ' has nodes with the same name: { duplicated } "
261
+ )
262
+
263
+
264
+ def _convert_nodes_to_dictionary (list_of_tuples ):
265
+ nodes = {}
266
+ counter = itertools .count ()
267
+ names = [x [0 ] for x in list_of_tuples if len (x ) == 2 ]
268
+
269
+ for tuple_ in list_of_tuples :
270
+ if len (tuple_ ) == 2 :
271
+ node_name , node = tuple_
272
+ nodes [node_name ] = node
273
+
274
+ else :
275
+ while True :
276
+ node_name = next (counter )
277
+ if node_name not in names :
278
+ break
279
+
280
+ nodes [node_name ] = tuple_ [0 ]
281
+
282
+ return nodes
0 commit comments