no-op #6895
Open · wants to merge 1 commit into base: master
@@ -490,6 +490,8 @@ def _generate_tasks_for_node(
         execution_type=node.node_info.type,
         contexts=resolved_info.contexts,
         input_and_params=unprocessed_inputs,
+        pipeline=self._pipeline,
+        node_id=node.node_info.id,
     )

     for execution in executions:
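Both task generators now thread the pipeline IR and the node id through to `task_gen_utils.register_executions`. A minimal sketch of the new call shape; every name besides the two new keyword arguments stands in for generator internals and is an assumption:

```python
# Hedged sketch of the updated call site inside a task generator.
executions = task_gen_utils.register_executions(
    metadata_handle=self._mlmd_handle,      # assumed handle owned by the generator
    execution_type=node.node_info.type,
    contexts=resolved_info.contexts,
    input_and_params=unprocessed_inputs,
    pipeline=self._pipeline,    # new: pipeline IR proto held by the generator
    node_id=node.node_info.id,  # new: id of the node these executions belong to
)
```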
8 changes: 8 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_state.py
@@ -1673,3 +1673,11 @@ def get_pipeline_and_node(
         'pipeline nodes are supported for external executions.'
     )
   return (pipeline_state.pipeline, node)
+
+
+def get_pipeline(
+    mlmd_handle: metadata.Metadata, pipeline_id: str
+) -> pipeline_pb2.Pipeline:
+  """Loads the pipeline proto for a pipeline from its latest execution."""
+  pipeline_view = PipelineView.load(mlmd_handle, pipeline_id)
+  return pipeline_view.pipeline
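A hedged usage sketch of the new `get_pipeline` helper; the connection config and pipeline id are assumptions for illustration:

```python
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import pipeline_state as pstate

# Assumed: a valid MLMD connection config and a previously run pipeline id.
with metadata.Metadata(connection_config) as mlmd_handle:
  pipeline = pstate.get_pipeline(mlmd_handle, 'my_pipeline')
  for node in pipeline.nodes:
    print(node.pipeline_node.node_info.id)
```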
2 changes: 2 additions & 0 deletions tfx/orchestration/experimental/core/sync_pipeline_task_gen.py
@@ -564,6 +564,8 @@ def _generate_tasks_from_resolved_inputs(
         execution_type=node.node_info.type,
         contexts=resolved_info.contexts,
         input_and_params=resolved_info.input_and_params,
+        pipeline=self._pipeline,
+        node_id=node.node_info.id,
     )

     result.extend(
61 changes: 52 additions & 9 deletions tfx/orchestration/experimental/core/task_gen_utils.py
@@ -30,6 +30,7 @@
 from tfx.orchestration import metadata
 from tfx.orchestration import node_proto_view
 from tfx.orchestration.experimental.core import constants
+from tfx.orchestration.experimental.core import env
 from tfx.orchestration.experimental.core import mlmd_state
 from tfx.orchestration.experimental.core import task as task_lib
 from tfx.orchestration import mlmd_connection_manager as mlmd_cm
@@ -548,21 +549,41 @@ def register_executions_from_existing_executions(
   contexts = metadata_handle.store.get_contexts_by_execution(
       existing_executions[0].id
   )
-  return execution_lib.put_executions(
+  executions = execution_lib.put_executions(
       metadata_handle,
       new_executions,
       contexts,
       input_artifacts_maps=input_artifacts,
   )
+
+  pipeline_asset = metadata_handle.store.pipeline_asset
+  if pipeline_asset:
+    env.get_env().create_pipeline_run_node_executions(
+        pipeline_asset.owner,
+        pipeline_asset.name,
+        pipeline,
+        node.node_info.id,
+        executions,
+    )
+  else:
+    logging.warning(
+        'Pipeline asset %s not found in MLMD. Unable to create pipeline run'
+        ' node executions.',
+        pipeline_asset,
+    )
+  return executions


+# TODO(b/349654866): make pipeline and node_id non-optional.
 def register_executions(
     metadata_handle: metadata.Metadata,
     execution_type: metadata_store_pb2.ExecutionType,
     contexts: Sequence[metadata_store_pb2.Context],
     input_and_params: Sequence[InputAndParam],
+    pipeline: Optional[pipeline_pb2.Pipeline] = None,
+    node_id: Optional[str] = None,
 ) -> Sequence[metadata_store_pb2.Execution]:
-  """Registers multiple executions in MLMD.
+  """Registers multiple executions in storage backends.

   Along with the execution:
   - the input artifacts will be linked to the executions.
@@ -575,6 +596,8 @@ def register_executions(
     input_and_params: A list of InputAndParams, which includes input_dicts
         (dictionaries of artifacts. One execution will be registered for each of
         the input_dict) and corresponding exec_properties.
+    pipeline: Optional. The pipeline proto.
+    node_id: Optional. The node id of the executions to be registered.

   Returns:
     A list of MLMD executions that are registered in MLMD, with id populated.
@@ -603,21 +626,41 @@ def register_executions(
       executions.append(execution)

   if len(executions) == 1:
-    return [
+    new_executions = [
         execution_lib.put_execution(
             metadata_handle,
             executions[0],
             contexts,
             input_artifacts=input_and_params[0].input_artifacts,
         )
     ]
+  else:
+    new_executions = execution_lib.put_executions(
+        metadata_handle,
+        executions,
+        contexts,
+        [
+            input_and_param.input_artifacts
+            for input_and_param in input_and_params
+        ],
+    )

-  return execution_lib.put_executions(
-      metadata_handle,
-      executions,
-      contexts,
-      [input_and_param.input_artifacts for input_and_param in input_and_params],
-  )
+  pipeline_asset = metadata_handle.store.pipeline_asset
+  if pipeline_asset and pipeline and node_id:
+    env.get_env().create_pipeline_run_node_executions(
+        pipeline_asset.owner,
+        pipeline_asset.name,
+        pipeline,
+        node_id,
+        new_executions,
+    )
+  else:
+    logging.warning(
+        'Skipping creating pipeline run node executions for pipeline asset %s.',
+        pipeline_asset,
+    )
+
+  return new_executions


 def update_external_artifact_type(
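Because `pipeline` and `node_id` default to `None` (pending the TODO above), existing callers of `register_executions` keep working unchanged; the environment hook fires only when both are supplied and the MLMD store exposes a `pipeline_asset`. A hedged sketch of a direct call:

```python
# Sketch only: metadata_handle, execution_type, contexts, input_and_params,
# and pipeline_proto are assumed to be prepared by the caller.
executions = task_gen_utils.register_executions(
    metadata_handle=metadata_handle,
    execution_type=execution_type,
    contexts=contexts,
    input_and_params=input_and_params,
    pipeline=pipeline_proto,  # optional today; omitting it skips the env hook
    node_id='trainer',        # hypothetical node id
)
```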
7 changes: 5 additions & 2 deletions tfx/orchestration/portable/execution_publish_utils.py
@@ -37,7 +37,7 @@ def publish_cached_executions(
     output_artifacts_maps: Optional[
         Sequence[typing_utils.ArtifactMultiMap]
     ] = None,
-) -> None:
+) -> Sequence[metadata_store_pb2.Execution]:
   """Marks an existing execution as using cached outputs from a previous execution.

   Args:
@@ -46,11 +46,14 @@ def publish_cached_executions(
     executions: Executions that will be published as CACHED executions.
     output_artifacts_maps: A list of output artifacts of the executions. Each
         artifact will be linked with the execution through an event of type OUTPUT.
+
+  Returns:
+    A list of MLMD executions that are published to MLMD, with id populated.
   """
   for execution in executions:
     execution.last_known_state = metadata_store_pb2.Execution.CACHED

-  execution_lib.put_executions(
+  return execution_lib.put_executions(
       metadata_handle,
       executions,
       contexts,
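With the return type changed from `None` to the published executions, callers can feed cache hits into the same environment hook. A hedged sketch of the new contract; the surrounding variables mirror partial_run_utils and are assumed:

```python
# Sketch: republish cached executions and verify MLMD populated their ids.
new_executions = execution_publish_utils.publish_cached_executions(
    mlmd_handle,
    contexts=cached_execution_contexts,
    executions=new_cached_executions,
    output_artifacts_maps=output_artifacts_maps,
)
assert all(e.id for e in new_executions)  # ids populated by put_executions
```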
22 changes: 20 additions & 2 deletions tfx/orchestration/portable/importer_node_handler.py
@@ -20,6 +20,8 @@
 from tfx.dsl.components.common import importer
 from tfx.orchestration import data_types_utils
 from tfx.orchestration import metadata
+from tfx.orchestration.experimental.core import env
+from tfx.orchestration.experimental.core import pipeline_state as pstate
 from tfx.orchestration.portable import data_types
 from tfx.orchestration.portable import execution_publish_utils
 from tfx.orchestration.portable import inputs_utils
@@ -57,7 +59,7 @@ def run(

     Args:
       mlmd_connection: ML metadata connection.
-      pipeline_node: The specification of the node that this launcher lauches.
+      pipeline_node: The specification of the node that this launcher launches.
       pipeline_info: The information of the pipeline that this node runs in.
       pipeline_runtime_spec: The runtime information of the pipeline that this
         node runs in.
@@ -78,13 +80,29 @@ def run(
         inputs_utils.resolve_parameters_with_schema(
            node_parameters=pipeline_node.parameters))

-    # 3. Registers execution in metadata.
+    # 3. Registers execution in storage backend.
     execution = execution_publish_utils.register_execution(
         metadata_handle=m,
         execution_type=pipeline_node.node_info.type,
         contexts=contexts,
         exec_properties=exec_properties,
     )
+    pipeline_asset = m.store.pipeline_asset
+    if pipeline_asset:
+      env.get_env().create_pipeline_run_node_executions(
+          pipeline_asset.owner,
+          pipeline_asset.name,
+          pstate.get_pipeline(m, pipeline_info.id),
+          pipeline_node.node_info.id,
+          [execution],
+      )
+    else:
+      logging.warning(
+          'Pipeline asset %s not found in MLMD. Unable to create pipeline run'
+          ' node execution %s.',
+          pipeline_asset,
+          execution,
+      )

     # 4. Generate output artifacts to represent the imported artifacts.
     output_key = cast(str, exec_properties[importer.OUTPUT_KEY_KEY])
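The repeated guard-then-dispatch pattern bottoms out in `env.get_env()`, which is what keeps this PR behaviorally a no-op wherever the hook is not implemented. A hedged sketch of what an environment override might look like; the method signature is inferred from the call sites in this diff, not from the actual `env.Env` definition:

```python
from typing import Sequence

from tfx.orchestration.experimental.core import env
from tfx.proto.orchestration import pipeline_pb2
from ml_metadata.proto import metadata_store_pb2


class _SketchEnv(env.Env):
  """Illustrative only; env.Env has other abstract methods not shown."""

  def create_pipeline_run_node_executions(
      self,
      owner: str,
      pipeline_name: str,
      pipeline: pipeline_pb2.Pipeline,
      node_id: str,
      executions: Sequence[metadata_store_pb2.Execution],
  ) -> None:
    # Default behavior: nothing to mirror into an external backend.
    del owner, pipeline_name, pipeline, node_id, executions
```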
46 changes: 39 additions & 7 deletions tfx/orchestration/portable/partial_run_utils.py
@@ -24,6 +24,7 @@
 from tfx.dsl.compiler import constants
 from tfx.orchestration import metadata
 from tfx.orchestration import node_proto_view
+from tfx.orchestration.experimental.core import env
 from tfx.orchestration.portable import execution_publish_utils
 from tfx.orchestration.portable.mlmd import context_lib
 from tfx.orchestration.portable.mlmd import execution_lib
@@ -599,6 +600,8 @@ def __init__(
         for node in node_proto_view.get_view_for_all_in(new_pipeline_run_ir)
     }

+    self._pipeline = new_pipeline_run_ir
+
   def _get_base_pipeline_run_context(
       self, base_run_id: Optional[str] = None
   ) -> metadata_store_pb2.Context:
@@ -788,7 +791,12 @@ def _cache_and_publish(
            contexts=[self._new_pipeline_run_context] + node_contexts,
        )
    )
-    if not prev_cache_executions:
+
+    # If there are no previous attempts to cache and publish, we will create new
+    # cache executions.
+    create_new_cache_executions: bool = not prev_cache_executions
+
+    if create_new_cache_executions:
       new_cached_executions = []
       for e in existing_executions:
         new_cached_executions.append(
@@ -820,12 +828,36 @@ def _cache_and_publish(
         execution_lib.get_output_artifacts(self._mlmd, e.id)
         for e in existing_executions
     ]
-    execution_publish_utils.publish_cached_executions(
-        self._mlmd,
-        contexts=cached_execution_contexts,
-        executions=new_cached_executions,
-        output_artifacts_maps=output_artifacts_maps,
-    )
+
+    if create_new_cache_executions:
+      new_executions = execution_publish_utils.publish_cached_executions(
+          self._mlmd,
+          contexts=cached_execution_contexts,
+          executions=new_cached_executions,
+          output_artifacts_maps=output_artifacts_maps,
+      )
+      pipeline_asset = self._mlmd.store.pipeline_asset
+      if pipeline_asset:
+        env.get_env().create_pipeline_run_node_executions(
+            pipeline_asset.owner,
+            pipeline_asset.name,
+            self._pipeline,
+            node.node_info.id,
+            new_executions,
+        )
+      else:
+        logging.warning(
+            'Pipeline asset %s not found in MLMD. Unable to create pipeline run'
+            ' node executions.',
+            pipeline_asset,
+        )
+    else:
+      execution_publish_utils.publish_cached_executions(
+          self._mlmd,
+          contexts=cached_execution_contexts,
+          executions=new_cached_executions,
+          output_artifacts_maps=output_artifacts_maps,
+      )

   def put_parent_context(self):
     """Puts a ParentContext edge in MLMD."""
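Note the asymmetry in `_cache_and_publish`: only the branch that creates new cached executions notifies the environment; a retry that found prior cache executions republishes without the hook. A toy sketch of the flag's contract, with placeholder values standing in for MLMD query results:

```python
# First attempt: no prior cache executions in MLMD -> publish and call env hook.
prev_cache_executions = []
assert (not prev_cache_executions) is True

# Retry path: prior cache executions exist -> publish only, no env hook.
prev_cache_executions = ['earlier attempt']
assert (not prev_cache_executions) is False
```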
37 changes: 35 additions & 2 deletions tfx/orchestration/portable/resolver_node_handler.py
@@ -20,6 +20,8 @@
 import grpc
 from tfx.orchestration import data_types_utils
 from tfx.orchestration import metadata
+from tfx.orchestration.experimental.core import env
+from tfx.orchestration.experimental.core import pipeline_state as pstate
 from tfx.orchestration.portable import data_types
 from tfx.orchestration.portable import execution_publish_utils
 from tfx.orchestration.portable import inputs_utils
@@ -86,6 +88,22 @@ def run(
           contexts=contexts,
           exec_properties=exec_properties,
       )
+      pipeline_asset = m.store.pipeline_asset
+      if pipeline_asset:
+        env.get_env().create_pipeline_run_node_executions(
+            pipeline_asset.owner,
+            pipeline_asset.name,
+            pstate.get_pipeline(m, pipeline_info.id),
+            pipeline_node.node_info.id,
+            [execution],
+        )
+      else:
+        logging.warning(
+            'Pipeline asset %s not found in MLMD. Unable to create pipeline'
+            ' run node execution %s.',
+            pipeline_asset,
+            execution,
+        )
       execution_publish_utils.publish_failed_execution(
           metadata_handle=m,
           contexts=contexts,
@@ -103,14 +121,29 @@ def run(
       if isinstance(resolved_inputs, inputs_utils.Skip):
         return data_types.ExecutionInfo()

-      # 3. Registers execution in metadata.
+      # 3. Registers execution in storage backends.
       execution = execution_publish_utils.register_execution(
           metadata_handle=m,
           execution_type=pipeline_node.node_info.type,
           contexts=contexts,
           exec_properties=exec_properties,
       )
+
+      pipeline_asset = m.store.pipeline_asset
+      if pipeline_asset:
+        env.get_env().create_pipeline_run_node_executions(
+            pipeline_asset.owner,
+            pipeline_asset.name,
+            pstate.get_pipeline(m, pipeline_info.id),
+            pipeline_node.node_info.id,
+            [execution],
+        )
+      else:
+        logging.warning(
+            'Pipeline asset %s not found in MLMD. Unable to create pipeline'
+            ' run node execution %s.',
+            pipeline_asset,
+            execution,
+        )
       # TODO(b/197741942): Support len > 1.
       if len(resolved_inputs) > 1:
         execution_publish_utils.publish_failed_execution(