From 770c65076fd9a9edd97ef27e951100a80fa7cf3f Mon Sep 17 00:00:00 2001 From: Biao He Date: Tue, 3 Dec 2024 23:44:05 -0800 Subject: [PATCH 1/6] Reorder serialization order and override cache Signed-off-by: Stefan He --- dev-requirements.in | 6 +++--- flytekit/core/node.py | 6 ++++++ flytekit/tools/serialize_helpers.py | 7 ++++++- pyproject.toml | 2 +- .../flytekit/unit/core/test_node_creation.py | 20 +++++++++++++++---- 5 files changed, 32 insertions(+), 9 deletions(-) diff --git a/dev-requirements.in b/dev-requirements.in index 20aba11e9d..cb11137a47 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -23,13 +23,13 @@ setuptools_scm pytest-icdiff # Tensorflow is not available for python 3.12 yet: https://github.com/tensorflow/tensorflow/issues/62003 -tensorflow<=2.15.1; python_version<'3.12' +#tensorflow<=2.15.1; python_version<'3.12' # Newer versions of torch bring in nvidia dependencies that are not present in windows, so # we put this constraint while we do not have per-environment requirements files -torch<=1.12.1; python_version<'3.11' +# torch<=1.12.1; python_version<'3.11' # pytorch 2 supports python 3.11 # pytorch 2 does not support 3.12 yet: https://github.com/pytorch/pytorch/issues/110436 -torch; python_version<'3.12' +# torch; python_version<'3.12' pydantic # TODO: Currently, the python-magic library causes build errors on Windows due to its dependency on DLLs for libmagic. diff --git a/flytekit/core/node.py b/flytekit/core/node.py index ea089c6fd3..2b6b0275e1 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -212,14 +212,20 @@ def with_overrides( if cache is not None: assert_not_promise(cache, "cache") self._metadata._cacheable = cache + if self.run_entity and self.run_entity.metadata is not None: + self.run_entity.metadata.cache = cache if cache_version is not None: assert_not_promise(cache_version, "cache_version") self._metadata._cache_version = cache_version + if self.run_entity and self.run_entity.metadata is not None: + self.run_entity.metadata.cache_version = cache_version if cache_serialize is not None: assert_not_promise(cache_serialize, "cache_serialize") self._metadata._cache_serializable = cache_serialize + if self.run_entity and self.run_entity.metadata is not None: + self.run_entity.metadata.cache_serialize = cache_serialize return self diff --git a/flytekit/tools/serialize_helpers.py b/flytekit/tools/serialize_helpers.py index 8d4cfcb99c..0231b77374 100644 --- a/flytekit/tools/serialize_helpers.py +++ b/flytekit/tools/serialize_helpers.py @@ -50,9 +50,14 @@ def get_registrable_entities( that are not known to Admin """ new_api_serializable_entities = OrderedDict() + + # Sort entities to process workflows and launch plans before tasks # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan # object, which gets added to the FlyteEntities.entities list, which we're iterating over. - for entity in flyte_context.FlyteEntities.entities.copy(): + sorted_entities = sorted( + flyte_context.FlyteEntities.entities.copy(), key=lambda x: 0 if isinstance(x, (WorkflowBase, LaunchPlan)) else 1 + ) + for entity in sorted_entities: if isinstance(entity, PythonTask) or isinstance(entity, WorkflowBase) or isinstance(entity, LaunchPlan): get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity, options=options) diff --git a/pyproject.toml b/pyproject.toml index 73d228d8af..a39f787bae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.13.9", + "flyteidl>=1.13", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", "googleapis-common-protos>=1.57", diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 381f456bdb..16e6875826 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -2,12 +2,14 @@ import typing from collections import OrderedDict from dataclasses import dataclass +from pathlib import Path import pytest import flytekit.configuration from flytekit import Resources, map_task from flytekit.configuration import Image, ImageConfig +from flytekit.core.base_task import PythonTask from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.node_creation import create_node from flytekit.core.task import task @@ -17,6 +19,7 @@ from flytekit.image_spec.image_spec import ImageBuildEngine from flytekit.models import literals as _literal_models from flytekit.models.task import Resources as _resources_models +from flytekit.tools.repo import load_packages_and_modules from flytekit.tools.translator import get_serializable @@ -513,8 +516,17 @@ def my_wf(a: str) -> str: image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), env={}, ) - wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) - assert wf_spec.template.nodes[0].metadata.cache_serializable - assert wf_spec.template.nodes[0].metadata.cacheable - assert wf_spec.template.nodes[0].metadata.cache_version == "foo" + registrable_entities = load_packages_and_modules( + ss=serialization_settings, + project_root=Path(__file__).parent.parent.parent.parent, + pkgs_or_mods=[str(__file__)], + ) + + # Find our specific task by name + for entity in registrable_entities: + if (isinstance(entity, PythonTask)): + assert entity is not None + assert entity.template.metadata.discoverable + assert entity.template.metadata.cache_version == "foo" + assert entity.template.metadata.cache_serializable From e742b116462f246759dd766bb5290066da33d2d4 Mon Sep 17 00:00:00 2001 From: Stefan He Date: Tue, 3 Dec 2024 23:52:56 -0800 Subject: [PATCH 2/6] revert unnecessary changes --- dev-requirements.in | 6 +++--- pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dev-requirements.in b/dev-requirements.in index cb11137a47..20aba11e9d 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -23,13 +23,13 @@ setuptools_scm pytest-icdiff # Tensorflow is not available for python 3.12 yet: https://github.com/tensorflow/tensorflow/issues/62003 -#tensorflow<=2.15.1; python_version<'3.12' +tensorflow<=2.15.1; python_version<'3.12' # Newer versions of torch bring in nvidia dependencies that are not present in windows, so # we put this constraint while we do not have per-environment requirements files -# torch<=1.12.1; python_version<'3.11' +torch<=1.12.1; python_version<'3.11' # pytorch 2 supports python 3.11 # pytorch 2 does not support 3.12 yet: https://github.com/pytorch/pytorch/issues/110436 -# torch; python_version<'3.12' +torch; python_version<'3.12' pydantic # TODO: Currently, the python-magic library causes build errors on Windows due to its dependency on DLLs for libmagic. diff --git a/pyproject.toml b/pyproject.toml index a39f787bae..73d228d8af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.13", + "flyteidl>=1.13.9", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", "googleapis-common-protos>=1.57", From 9214e1035767e7d4235a09ff3a6e9908245e26dc Mon Sep 17 00:00:00 2001 From: Stefan He Date: Wed, 4 Dec 2024 17:12:08 -0800 Subject: [PATCH 3/6] Add unit test --- .../flytekit/unit/core/test_node_creation.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 16e6875826..7b7bd22674 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -2,14 +2,12 @@ import typing from collections import OrderedDict from dataclasses import dataclass -from pathlib import Path import pytest import flytekit.configuration from flytekit import Resources, map_task from flytekit.configuration import Image, ImageConfig -from flytekit.core.base_task import PythonTask from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.node_creation import create_node from flytekit.core.task import task @@ -19,7 +17,6 @@ from flytekit.image_spec.image_spec import ImageBuildEngine from flytekit.models import literals as _literal_models from flytekit.models.task import Resources as _resources_models -from flytekit.tools.repo import load_packages_and_modules from flytekit.tools.translator import get_serializable @@ -517,16 +514,18 @@ def my_wf(a: str) -> str: env={}, ) - registrable_entities = load_packages_and_modules( - ss=serialization_settings, - project_root=Path(__file__).parent.parent.parent.parent, - pkgs_or_mods=[str(__file__)], - ) + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + assert not task_spec.template.metadata.discoverable + assert task_spec.template.metadata.discovery_version != "foo" + assert not task_spec.template.metadata.cache_serializable - # Find our specific task by name - for entity in registrable_entities: - if (isinstance(entity, PythonTask)): - assert entity is not None - assert entity.template.metadata.discoverable - assert entity.template.metadata.cache_version == "foo" - assert entity.template.metadata.cache_serializable + wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) + assert len(wf_spec.template.nodes) == 1 + assert wf_spec.template.nodes[0].metadata.cacheable + assert wf_spec.template.nodes[0].metadata.cache_version == "foo" + assert wf_spec.template.nodes[0].metadata.cache_serializable + + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + assert task_spec.template.metadata.discoverable + assert task_spec.template.metadata.discovery_version == "foo" + assert task_spec.template.metadata.cache_serializable From 06967918429b6df4ad5b63db6d61099d64188ed9 Mon Sep 17 00:00:00 2001 From: Stefan He Date: Wed, 4 Dec 2024 17:21:30 -0800 Subject: [PATCH 4/6] Fix Unit test --- tests/flytekit/unit/cli/pyflyte/test_package.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/flytekit/unit/cli/pyflyte/test_package.py b/tests/flytekit/unit/cli/pyflyte/test_package.py index 72dca288eb..2c7ef7ec9d 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_package.py +++ b/tests/flytekit/unit/cli/pyflyte/test_package.py @@ -116,9 +116,9 @@ def test_package_with_fast_registration_and_envvars(): # Uncompress flyte-package.tgz tarfile.open("flyte-package.tgz", "r:gz").extractall() - # Load the proto message from file 3_core.sample.sum_1.pb + # Load the proto message from file 4_core.sample.sum_1.pb task_spec = task_pb2.TaskSpec() - task_spec.ParseFromString(open("3_core.sample.sum_1.pb", "rb").read()) + task_spec.ParseFromString(open("4_core.sample.sum_1.pb", "rb").read()) assert task_spec.template.container.env[0].key == "abc" assert task_spec.template.container.env[0].value == "42" @@ -148,9 +148,9 @@ def test_package_with_fast_registration_and_envvars(): tarfile.open("flyte-package.tgz", "r:gz").extractall() - # Load the proto message from file 3_core.sample.sum_1.pb + # Load the proto message from file 4_core.sample.sum_1.pb task_spec = task_pb2.TaskSpec() - task_spec.ParseFromString(open("3_core.sample.sum_1.pb", "rb").read()) + task_spec.ParseFromString(open("4_core.sample.sum_1.pb", "rb").read()) assert task_spec.template.container.env[0].key == "k1" assert task_spec.template.container.env[0].value == "v1" From f30ba9d75f4cabb8541f022f196db01726786fa9 Mon Sep 17 00:00:00 2001 From: Stefan He Date: Wed, 4 Dec 2024 23:03:56 -0800 Subject: [PATCH 5/6] Fix test --- flytekit/core/node.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 2b6b0275e1..aedd7c8c6c 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -194,6 +194,8 @@ def with_overrides( if name is not None: self._metadata._name = name + if self.run_entity and hasattr(self.run_entity, "metadata"): + self.run_entity.metadata.name = name if task_config is not None: logger.warning("This override is beta. We may want to revisit this in the future.") @@ -212,19 +214,19 @@ def with_overrides( if cache is not None: assert_not_promise(cache, "cache") self._metadata._cacheable = cache - if self.run_entity and self.run_entity.metadata is not None: + if getattr(self.run_entity, "metadata", None): self.run_entity.metadata.cache = cache if cache_version is not None: assert_not_promise(cache_version, "cache_version") self._metadata._cache_version = cache_version - if self.run_entity and self.run_entity.metadata is not None: + if getattr(self.run_entity, "metadata", None): self.run_entity.metadata.cache_version = cache_version if cache_serialize is not None: assert_not_promise(cache_serialize, "cache_serialize") self._metadata._cache_serializable = cache_serialize - if self.run_entity and self.run_entity.metadata is not None: + if getattr(self.run_entity, "metadata", None): self.run_entity.metadata.cache_serialize = cache_serialize return self From 698d62bc4dba5758d5e1153e1d5d9c7b5127744a Mon Sep 17 00:00:00 2001 From: Stefan He Date: Thu, 5 Dec 2024 15:34:06 -0800 Subject: [PATCH 6/6] fix --- flytekit/core/node.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index aedd7c8c6c..73b19526e6 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -214,19 +214,19 @@ def with_overrides( if cache is not None: assert_not_promise(cache, "cache") self._metadata._cacheable = cache - if getattr(self.run_entity, "metadata", None): + if hasattr(self.run_entity, "metadata") and self.run_entity.metadata is not None: self.run_entity.metadata.cache = cache if cache_version is not None: assert_not_promise(cache_version, "cache_version") self._metadata._cache_version = cache_version - if getattr(self.run_entity, "metadata", None): + if hasattr(self.run_entity, "metadata") and self.run_entity.metadata is not None: self.run_entity.metadata.cache_version = cache_version if cache_serialize is not None: assert_not_promise(cache_serialize, "cache_serialize") self._metadata._cache_serializable = cache_serialize - if getattr(self.run_entity, "metadata", None): + if hasattr(self.run_entity, "metadata") and self.run_entity.metadata is not None: self.run_entity.metadata.cache_serialize = cache_serialize return self