Skip to content

Commit

Permalink
Add specialized function in the client to get proper return types
Browse files Browse the repository at this point in the history
  • Loading branch information
romain-intel committed Oct 20, 2023
1 parent 31a9048 commit 93e511c
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 21 deletions.
176 changes: 158 additions & 18 deletions metaflow/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,17 @@
from datetime import datetime
from io import BytesIO
from itertools import chain
from typing import Any, Dict, FrozenSet, Iterable, List, NamedTuple, Optional, Tuple
from typing import (
Any,
Dict,
FrozenSet,
Iterable,
Iterator,
List,
NamedTuple,
Optional,
Tuple,
)

from metaflow.current import current
from metaflow.events import Trigger
Expand Down Expand Up @@ -340,12 +350,12 @@ def _get_object(self, *path_components):
raise MetaflowNotFound("%s does not exist" % self)
return result

def __iter__(self) -> Iterable["MetaflowObject"]:
def __iter__(self) -> Iterator["MetaflowObject"]:
"""
Iterate over all child objects of this object if any.
Note that only children present in the current namespace are returned iff
_namespace_check is set.
Note that only children present in the current namespace are returned if and
only if _namespace_check is set.
Yields
------
Expand Down Expand Up @@ -1445,7 +1455,7 @@ def loglines(
stream: str,
as_unicode: bool = True,
meta_dict: Optional[Dict[str, Any]] = None,
) -> Iterable[Tuple[datetime, str]]:
) -> Iterator[Tuple[datetime, str]]:
"""
Return an iterator over (utc_timestamp, logline) tuples.
Expand Down Expand Up @@ -1531,6 +1541,39 @@ def _log_size(self, stream, meta_dict):
ds_type, ds_root, stream, attempt, *self.path_components
)

def __iter__(self) -> Iterator[DataArtifact]:
    """
    Iterate over all children DataArtifact of this Task

    Yields
    ------
    DataArtifact
        A DataArtifact in this Task
    """
    # Delegate to the parent class iterator; this override exists to expose
    # the precise element type to callers and type checkers.
    yield from super(Task, self).__iter__()

def __getitem__(self, name: str) -> DataArtifact:
    """
    Look up a child DataArtifact of this Task by its artifact name.

    Parameters
    ----------
    name : str
        Data artifact name

    Returns
    -------
    DataArtifact
        DataArtifact for this artifact name in this task

    Raises
    ------
    KeyError
        If the name does not identify a valid DataArtifact object
    """
    # The lookup itself is implemented by the parent class.
    artifact = super(Task, self).__getitem__(name)
    return artifact

def __getstate__(self):
    """Return the object state as produced by the parent class `__getstate__`."""
    state = super(Task, self).__getstate__()
    return state

Expand Down Expand Up @@ -1613,7 +1656,7 @@ def control_task(self) -> Optional[Task]:
"""
return next(self.control_tasks(), None)

def control_tasks(self, *tags: str) -> Iterable[Task]:
def control_tasks(self, *tags: str) -> Iterator[Task]:
"""
[Unpublished API - use with caution!]
Expand Down Expand Up @@ -1650,11 +1693,39 @@ def control_tasks(self, *tags: str) -> Iterable[Task]:
):
yield child

def __iter__(self):
children = super(Step, self).__iter__()
for t in children:
def __iter__(self) -> Iterator[Task]:
    """
    Iterate over every child Task of this Step.

    Yields
    ------
    Task
        A Task in this Step
    """
    # Delegate straight to the parent class iterator.
    yield from super(Step, self).__iter__()

def __getitem__(self, task_id: str) -> Task:
    """
    Look up a child Task of this Step by its task ID.

    Parameters
    ----------
    task_id : str
        Task ID

    Returns
    -------
    Task
        Task for this task ID in this Step

    Raises
    ------
    KeyError
        If the task_id does not identify a valid Task object
    """
    # The lookup itself is implemented by the parent class.
    task = super(Step, self).__getitem__(task_id)
    return task

def __getstate__(self):
    """Return the object state as produced by the parent class `__getstate__`."""
    state = super(Step, self).__getstate__()
    return state

Expand Down Expand Up @@ -1730,7 +1801,7 @@ def _iter_filter(self, x):
# exclude _parameters step
return x.id[0] != "_"

def steps(self, *tags: str) -> Iterable[Step]:
def steps(self, *tags: str) -> Iterator[Step]:
"""
[Legacy function - do not use]
Expand Down Expand Up @@ -1983,6 +2054,39 @@ def replace_tags(self, tags_to_remove: Iterable[str], tags_to_add: Iterable[str]
self._user_tags = frozenset(final_user_tags)
self._tags = frozenset([*self._user_tags, *self._system_tags])

def __iter__(self) -> Iterator[Step]:
    """
    Iterate over every child Step of this Run.

    Yields
    ------
    Step
        A Step in this Run
    """
    # Delegate straight to the parent class iterator.
    yield from super(Run, self).__iter__()

def __getitem__(self, name: str) -> Step:
    """
    Look up a child Step of this Run by its step name.

    Parameters
    ----------
    name : str
        Step name

    Returns
    -------
    Step
        Step for this step name in this Run

    Raises
    ------
    KeyError
        If the name does not identify a valid Step object
    """
    # The lookup itself is implemented by the parent class.
    step = super(Run, self).__getitem__(name)
    return step

def __getstate__(self):
    """Return the object state as produced by the parent class `__getstate__`."""
    state = super(Run, self).__getstate__()
    return state

Expand Down Expand Up @@ -2058,7 +2162,7 @@ def latest_successful_run(self) -> Optional[Run]:
if run.successful:
return run

def runs(self, *tags: str) -> Iterable[Run]:
def runs(self, *tags: str) -> Iterator[Run]:
"""
Returns an iterator over all `Run`s of this flow.
Expand All @@ -2078,6 +2182,42 @@ def runs(self, *tags: str) -> Iterable[Run]:
"""
return self._filtered_children(*tags)

def __iter__(self) -> Iterator[Run]:
    """
    Iterate over all children Run of this Flow.

    Note that only runs in the current namespace are returned unless
    _namespace_check is False

    Yields
    ------
    Run
        A Run in this Flow
    """
    # The children of a Flow are Run objects (not Task objects), so the
    # return annotation must be Iterator[Run].
    yield from super(Flow, self).__iter__()

def __getitem__(self, run_id: str) -> Run:
    """
    Returns the Run object with the run ID 'run_id'

    Parameters
    ----------
    run_id : str
        Run ID

    Returns
    -------
    Run
        Run for this run ID in this Flow

    Raises
    ------
    KeyError
        If the run_id does not identify a valid Run object
    """
    # The lookup itself is implemented by the parent class.
    return super(Flow, self).__getitem__(run_id)

def __getstate__(self):
    """Return the object state as produced by the parent class `__getstate__`."""
    state = super(Flow, self).__getstate__()
    return state

Expand Down Expand Up @@ -2127,12 +2267,12 @@ def flows(self) -> List[Flow]:
"""
return list(self)

def __iter__(self):
def __iter__(self) -> Iterator[Flow]:
"""
Iterator over all flows present.
Only flows present in the set namespace are returned. A flow is present in a namespace if
it has at least one run that is in the namespace.
Only flows present in the set namespace are returned. A flow is present in a
namespace if it has at least one run that is in the namespace.
Yields
-------
Expand All @@ -2152,24 +2292,24 @@ def __iter__(self):
except MetaflowNamespaceMismatch:
continue

def __str__(self):
def __str__(self) -> str:
return "Metaflow()"

def __getitem__(self, id: str) -> Flow:
def __getitem__(self, name: str) -> Flow:
    """
    Returns a specific flow by name.

    The flow will only be returned if it is present in the current namespace.

    Parameters
    ----------
    name : str
        Name of the Flow

    Returns
    -------
    Flow
        Flow with the given name.
    """
    # Pass the `name` argument. After the parameter rename, the bare
    # identifier `id` would resolve to the builtin `id` function.
    return Flow(name)

Expand Down
2 changes: 1 addition & 1 deletion metaflow/flowspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def next(self, *dsts: Callable[..., None], **kwargs) -> None:
Parameters
----------
dsts : Method
dsts : Callable[..., None]
One or more methods annotated with `@step`.
Raises
Expand Down
4 changes: 2 additions & 2 deletions metaflow/multicore_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tempfile import NamedTemporaryFile
import time

from typing import Any, Callable, Iterable, List, Optional
from typing import Any, Callable, Iterable, Iterator, List, Optional

try:
# Python 2
Expand Down Expand Up @@ -66,7 +66,7 @@ def parallel_imap_unordered(
iterable: Iterable[Any],
max_parallel: Optional[int] = None,
dir: Optional[str] = None,
) -> Iterable[Any]:
) -> Iterator[Any]:
"""
Parallelizes execution of a function using multiprocessing. The result
order is not guaranteed.
Expand Down

0 comments on commit 93e511c

Please sign in to comment.