diff --git a/flytekit/core/node.py b/flytekit/core/node.py index ea089c6fd3..73b19526e6 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -194,6 +194,8 @@ def with_overrides( if name is not None: self._metadata._name = name + if self.run_entity and hasattr(self.run_entity, "metadata"): + self.run_entity.metadata.name = name if task_config is not None: logger.warning("This override is beta. We may want to revisit this in the future.") @@ -212,14 +214,20 @@ def with_overrides( if cache is not None: assert_not_promise(cache, "cache") self._metadata._cacheable = cache + if hasattr(self.run_entity, "metadata") and self.run_entity.metadata is not None: + self.run_entity.metadata.cache = cache if cache_version is not None: assert_not_promise(cache_version, "cache_version") self._metadata._cache_version = cache_version + if hasattr(self.run_entity, "metadata") and self.run_entity.metadata is not None: + self.run_entity.metadata.cache_version = cache_version if cache_serialize is not None: assert_not_promise(cache_serialize, "cache_serialize") self._metadata._cache_serializable = cache_serialize + if hasattr(self.run_entity, "metadata") and self.run_entity.metadata is not None: + self.run_entity.metadata.cache_serialize = cache_serialize return self diff --git a/flytekit/tools/serialize_helpers.py b/flytekit/tools/serialize_helpers.py index 8d4cfcb99c..0231b77374 100644 --- a/flytekit/tools/serialize_helpers.py +++ b/flytekit/tools/serialize_helpers.py @@ -50,9 +50,14 @@ def get_registrable_entities( that are not known to Admin """ new_api_serializable_entities = OrderedDict() + + # Sort entities to process workflows and launch plans before tasks # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan # object, which gets added to the FlyteEntities.entities list, which we're iterating over. - for entity in flyte_context.FlyteEntities.entities.copy(): + sorted_entities = sorted( + flyte_context.FlyteEntities.entities.copy(), key=lambda x: 0 if isinstance(x, (WorkflowBase, LaunchPlan)) else 1 + ) + for entity in sorted_entities: if isinstance(entity, PythonTask) or isinstance(entity, WorkflowBase) or isinstance(entity, LaunchPlan): get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity, options=options) diff --git a/tests/flytekit/unit/cli/pyflyte/test_package.py b/tests/flytekit/unit/cli/pyflyte/test_package.py index 72dca288eb..2c7ef7ec9d 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_package.py +++ b/tests/flytekit/unit/cli/pyflyte/test_package.py @@ -116,9 +116,9 @@ def test_package_with_fast_registration_and_envvars(): # Uncompress flyte-package.tgz tarfile.open("flyte-package.tgz", "r:gz").extractall() - # Load the proto message from file 3_core.sample.sum_1.pb + # Load the proto message from file 4_core.sample.sum_1.pb task_spec = task_pb2.TaskSpec() - task_spec.ParseFromString(open("3_core.sample.sum_1.pb", "rb").read()) + task_spec.ParseFromString(open("4_core.sample.sum_1.pb", "rb").read()) assert task_spec.template.container.env[0].key == "abc" assert task_spec.template.container.env[0].value == "42" @@ -148,9 +148,9 @@ def test_package_with_fast_registration_and_envvars(): tarfile.open("flyte-package.tgz", "r:gz").extractall() - # Load the proto message from file 3_core.sample.sum_1.pb + # Load the proto message from file 4_core.sample.sum_1.pb task_spec = task_pb2.TaskSpec() - task_spec.ParseFromString(open("3_core.sample.sum_1.pb", "rb").read()) + task_spec.ParseFromString(open("4_core.sample.sum_1.pb", "rb").read()) assert task_spec.template.container.env[0].key == "k1" assert task_spec.template.container.env[0].value == "v1" diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 381f456bdb..7b7bd22674 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -513,8 +513,19 @@ def my_wf(a: str) -> str: image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), env={}, ) - wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) - assert wf_spec.template.nodes[0].metadata.cache_serializable + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + assert not task_spec.template.metadata.discoverable + assert task_spec.template.metadata.discovery_version != "foo" + assert not task_spec.template.metadata.cache_serializable + + wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) + assert len(wf_spec.template.nodes) == 1 assert wf_spec.template.nodes[0].metadata.cacheable assert wf_spec.template.nodes[0].metadata.cache_version == "foo" + assert wf_spec.template.nodes[0].metadata.cache_serializable + + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + assert task_spec.template.metadata.discoverable + assert task_spec.template.metadata.discovery_version == "foo" + assert task_spec.template.metadata.cache_serializable