diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml index 6dc4bb4459..3c9f4039e4 100644 --- a/.github/workflows/pythonpublish.yml +++ b/.github/workflows/pythonpublish.yml @@ -57,10 +57,12 @@ jobs: echo "No tagged version found, exiting" exit 1 fi - LINK="https://pypi.org/project/flytekitplugins-pod/${VERSION}" + sleep 300 + LINK="https://pypi.org/project/flytekitplugins-pod/${VERSION}/" for i in {1..60}; do - if curl -L -I -s -f ${LINK} >/dev/null; then - echo "Found pypi" + result=$(curl -L -I -s -f ${LINK}) + if [ $? -eq 0 ]; then + echo "Found pypi for $LINK" exit 0 else echo "Did not find - Retrying in 10 seconds..." diff --git a/Dockerfile b/Dockerfile index 87515506dd..0f95e517ba 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,19 +10,30 @@ ENV PYTHONPATH /root ARG VERSION ARG DOCKER_IMAGE -RUN apt-get update && apt-get install build-essential -y +# Note: Pod tasks should be exposed in the default image +# Note: Some packages will create config files under /home by default, so we need to make sure it's writable +# Note: There are use cases that require reading and writing files under /tmp, so we need to change its permissions. -# Pod tasks should be exposed in the default image -RUN pip install --no-cache-dir -U flytekit==$VERSION \ - flytekitplugins-pod==$VERSION \ - flytekitplugins-deck-standard==$VERSION \ - scikit-learn \ - && : +# Run a series of commands to set up the environment: +# 1. Update and install dependencies. +# 2. Install Flytekit and its plugins. +# 3. Clean up the apt cache to reduce image size. Reference: https://gist.github.com/marvell/7c812736565928e602c4 +# 4. Create a non-root user 'flytekit' and set appropriate permissions for directories. +RUN apt-get update && apt-get install build-essential -y \ + && pip install --no-cache-dir -U flytekit==$VERSION \ + flytekitplugins-pod==$VERSION \ + flytekitplugins-deck-standard==$VERSION \ + scikit-learn \ + && apt-get clean autoclean \ + && apt-get autoremove --yes \ + && rm -rf /var/lib/{apt,dpkg,cache,log}/ \ + && useradd -u 1000 flytekit \ + && chown flytekit: /root \ + && chown flytekit: /home \ + && chown -R flytekit: /tmp \ + && chmod 755 /tmp \ + && : -RUN useradd -u 1000 flytekit -RUN chown flytekit: /root -# Some packages will create config file under /home by default, so we need to make sure it's writable -RUN chown flytekit: /home USER flytekit ENV FLYTE_INTERNAL_IMAGE "$DOCKER_IMAGE" diff --git a/Dockerfile.dev b/Dockerfile.dev index 032dd29429..26ffb88ffd 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -15,21 +15,34 @@ WORKDIR /root ARG VERSION -RUN apt-get update && apt-get install build-essential vim libmagic1 -y - COPY . /flytekit -# Pod tasks should be exposed in the default image -RUN pip install -e /flytekit -RUN pip install -e /flytekit/plugins/flytekit-k8s-pod -RUN pip install -e /flytekit/plugins/flytekit-deck-standard -RUN pip install -e /flytekit/plugins/flytekit-flyin -RUN pip install scikit-learn +# Note: Pod tasks should be exposed in the default image +# Note: Some packages will create config files under /home by default, so we need to make sure it's writable +# Note: There are use cases that require reading and writing files under /tmp, so we need to change its permissions. + +# Run a series of commands to set up the environment: +# 1. Update and install dependencies. +# 2. Install Flytekit and its plugins. +# 3. Clean up the apt cache to reduce image size. Reference: https://gist.github.com/marvell/7c812736565928e602c4 +# 4. 
Create a non-root user 'flytekit' and set appropriate permissions for directories. +RUN apt-get update && apt-get install build-essential vim libmagic1 -y \ + && pip install --no-cache-dir -e /flytekit \ + && pip install --no-cache-dir -e /flytekit/plugins/flytekit-k8s-pod \ + && pip install --no-cache-dir -e /flytekit/plugins/flytekit-deck-standard \ + && pip install --no-cache-dir -e /flytekit/plugins/flytekit-flyin \ + && pip install --no-cache-dir scikit-learn \ + && apt-get clean autoclean \ + && apt-get autoremove --yes \ + && rm -rf /var/lib/{apt,dpkg,cache,log}/ \ + && useradd -u 1000 flytekit \ + && chown flytekit: /root \ + && chown flytekit: /home \ + && chown -R flytekit: /tmp \ + && chmod 755 /tmp \ + && : ENV PYTHONPATH "/flytekit:/flytekit/plugins/flytekit-k8s-pod:/flytekit/plugins/flytekit-deck-standard:" -RUN useradd -u 1000 flytekit -RUN chown flytekit: /root -# Some packages will create config file under /home by default, so we need to make sure it's writable -RUN chown flytekit: /home +# Switch to the 'flytekit' user for better security. USER flytekit diff --git a/Makefile b/Makefile index 9e7c8866e3..0ae94f5604 100644 --- a/Makefile +++ b/Makefile @@ -35,9 +35,9 @@ fmt: lint: ## Run linters mypy flytekit/core mypy flytekit/types - # allow-empty-bodies: Allow empty body in function. - # disable-error-code="annotation-unchecked": Remove the warning "By default the bodies of untyped functions are not checked". - # Mypy raises a warning because it cannot determine the type from the dataclass, despite we specified the type in the dataclass. +# allow-empty-bodies: Allow empty body in function. +# disable-error-code="annotation-unchecked": Remove the warning "By default the bodies of untyped functions are not checked". +# Mypy raises a warning because it cannot determine the type from the dataclass, despite we specified the type in the dataclass. 
mypy --allow-empty-bodies --disable-error-code="annotation-unchecked" tests/flytekit/unit/core pre-commit run --all-files diff --git a/dev-requirements.txt b/dev-requirements.txt index 8a00c9588f..d10b386a7f 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -551,7 +551,6 @@ werkzeug==3.0.1 wheel==0.41.3 # via # astunparse - # flytekit # tensorboard wrapt==1.15.0 # via diff --git a/doc-requirements.txt b/doc-requirements.txt index 5336cfe0d3..a9449aee2d 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -1416,7 +1416,6 @@ werkzeug==2.3.7 wheel==0.41.2 # via # astunparse - # flytekit # tensorboard whylabs-client==0.5.7 # via diff --git a/flytekit/core/checkpointer.py b/flytekit/core/checkpointer.py index ee111979e7..7ac649a487 100644 --- a/flytekit/core/checkpointer.py +++ b/flytekit/core/checkpointer.py @@ -102,7 +102,7 @@ def restore(self, path: typing.Optional[typing.Union[Path, str]] = None) -> typi if path is None: p = Path(self._td.name) path = p.joinpath(self.SRC_LOCAL_FOLDER) - path.mkdir() + path.mkdir(exist_ok=True) elif isinstance(path, str): path = Path(path) diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index de85c0be97..833c7d8562 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -107,7 +107,7 @@ def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder: return self def build(self) -> ExecutionParameters: - if not isinstance(self.working_dir, utils.AutoDeletingTempDir): + if self.working_dir and not isinstance(self.working_dir, utils.AutoDeletingTempDir): pathlib.Path(typing.cast(str, self.working_dir)).mkdir(parents=True, exist_ok=True) return ExecutionParameters( execution_date=self.execution_date, diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 3bcf406d42..123cb4a0ef 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -757,6 +757,83 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]: raise ValueError(f"Transformer {self} cannot reverse {literal_type}") +class EnumTransformer(TypeTransformer[enum.Enum]): + """ + Enables converting a python type enum.Enum to LiteralType.EnumType + """ + + def __init__(self): + super().__init__(name="DefaultEnumTransformer", t=enum.Enum) + + def get_literal_type(self, t: Type[T]) -> LiteralType: + if is_annotated(t): + raise ValueError( + f"Flytekit does not currently have support \ + for FlyteAnnotations applied to enums. {t} cannot be \ + parsed." 
+ ) + + values = [v.value for v in t] # type: ignore + if not isinstance(values[0], str): + raise TypeTransformerFailedError("Only EnumTypes with value of string are supported") + return LiteralType(enum_type=_core_types.EnumType(values=values)) + + def to_literal( + self, ctx: FlyteContext, python_val: enum.Enum, python_type: Type[T], expected: LiteralType + ) -> Literal: + if type(python_val).__class__ != enum.EnumMeta: + raise TypeTransformerFailedError("Expected an enum") + if type(python_val.value) != str: + raise TypeTransformerFailedError("Only string-valued enums are supportedd") + + return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.value))) # type: ignore + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + return expected_python_type(lv.scalar.primitive.string_value) # type: ignore + + def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]: + if literal_type.enum_type: + return enum.Enum("DynamicEnum", {f"{i}": i for i in literal_type.enum_type.values}) # type: ignore + raise ValueError(f"Enum transformer cannot reverse {literal_type}") + + +def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any): + attribute_list = [] + for property_key, property_val in schema["properties"].items(): + if property_val.get("anyOf"): + property_type = property_val["anyOf"][0]["type"] + elif property_val.get("enum"): + property_type = "enum" + else: + property_type = property_val["type"] + # Handle list + if property_type == "array": + attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore + # Handle dataclass and dict + elif property_type == "object": + if property_val.get("anyOf"): + sub_schemea = property_val["anyOf"][0] + sub_schemea_name = sub_schemea["title"] + attribute_list.append( + (property_key, convert_mashumaro_json_schema_to_python_class(sub_schemea, sub_schemea_name)) + ) + elif property_val.get("additionalProperties"): + attribute_list.append( + (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore + ) + else: + sub_schemea_name = property_val["title"] + attribute_list.append( + (property_key, convert_mashumaro_json_schema_to_python_class(property_val, sub_schemea_name)) + ) + elif property_type == "enum": + attribute_list.append([property_key, str]) # type: ignore + # Handle int, float, bool or str + else: + attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore + return attribute_list + + class TypeEngine(typing.Generic[T]): """ Core Extensible TypeEngine of Flytekit. This should be used to extend the capabilities of FlyteKits type system. @@ -767,6 +844,7 @@ class TypeEngine(typing.Generic[T]): _REGISTRY: typing.Dict[type, TypeTransformer[T]] = {} _RESTRICTED_TYPES: typing.List[type] = [] _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore + _ENUM_TRANSFORMER: TypeTransformer = EnumTransformer() # type: ignore has_lazy_import = False @classmethod @@ -823,6 +901,9 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: Walk the inheritance hierarchy of v and find a transformer that matches the first base class. This is potentially non-deterministic - will depend on the registration pattern. + Special case: + If v inherits from Enum, use the Enum transformer even if Enum is not the first base class. 
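+        For example (an illustrative sketch, not a doctest): a user-defined `class FooEnum(str, Enum)`
+        resolves to this Enum transformer here rather than to the registered `str` transformer:
+
+            class FooEnum(str, Enum):
+                A = "a"
+
+            TypeEngine.get_transformer(FooEnum)  # returns cls._ENUM_TRANSFORMER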
+ TODO lets make this deterministic by using an ordered dict Step 5: @@ -838,6 +919,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: python_type = args[0] + # Step 2 # this makes sure that if it's a list/dict of annotated types, we hit the unwrapping code in step 2 # see test_list_of_annotated in test_structured_dataset.py if ( @@ -849,7 +931,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: ) and python_type in cls._REGISTRY: return cls._REGISTRY[python_type] - # Step 2 + # Step 3 if hasattr(python_type, "__origin__"): # Handling of annotated generics, eg: # Annotated[typing.List[int], 'foo'] @@ -861,9 +943,13 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: raise ValueError(f"Generic Type {python_type.__origin__} not supported currently in Flytekit.") - # Step 3 + # Step 4 # To facilitate cases where users may specify one transformer for multiple types that all inherit from one # parent. + if inspect.isclass(python_type) and issubclass(python_type, enum.Enum): + # Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used. + return cls._ENUM_TRANSFORMER + for base_type in cls._REGISTRY.keys(): if base_type is None: continue # None is actually one of the keys, but isinstance/issubclass doesn't work on it @@ -877,11 +963,11 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: # is the case for one of the restricted types, namely NamedTuple. logger.debug(f"Invalid base type {base_type} in call to isinstance", exc_info=True) - # Step 4 + # Step 5 if dataclasses.is_dataclass(python_type): return cls._DATACLASS_TRANSFORMER - # Step 5 + # Step 6 display_pickle_warning(str(python_type)) from flytekit.types.pickle.pickle import FlytePickleTransformer @@ -1607,83 +1693,6 @@ def to_python_value( return open(local_path, "rb") -class EnumTransformer(TypeTransformer[enum.Enum]): - """ - Enables converting a python type enum.Enum to LiteralType.EnumType - """ - - def __init__(self): - super().__init__(name="DefaultEnumTransformer", t=enum.Enum) - - def get_literal_type(self, t: Type[T]) -> LiteralType: - if is_annotated(t): - raise ValueError( - f"Flytekit does not currently have support \ - for FlyteAnnotations applied to enums. {t} cannot be \ - parsed." 
- ) - - values = [v.value for v in t] # type: ignore - if not isinstance(values[0], str): - raise TypeTransformerFailedError("Only EnumTypes with value of string are supported") - return LiteralType(enum_type=_core_types.EnumType(values=values)) - - def to_literal( - self, ctx: FlyteContext, python_val: enum.Enum, python_type: Type[T], expected: LiteralType - ) -> Literal: - if type(python_val).__class__ != enum.EnumMeta: - raise TypeTransformerFailedError("Expected an enum") - if type(python_val.value) != str: - raise TypeTransformerFailedError("Only string-valued enums are supportedd") - - return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.value))) # type: ignore - - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: - return expected_python_type(lv.scalar.primitive.string_value) # type: ignore - - def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]: - if literal_type.enum_type: - return enum.Enum("DynamicEnum", {f"{i}": i for i in literal_type.enum_type.values}) # type: ignore - raise ValueError(f"Enum transformer cannot reverse {literal_type}") - - -def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any): - attribute_list = [] - for property_key, property_val in schema["properties"].items(): - if property_val.get("anyOf"): - property_type = property_val["anyOf"][0]["type"] - elif property_val.get("enum"): - property_type = "enum" - else: - property_type = property_val["type"] - # Handle list - if property_type == "array": - attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore - # Handle dataclass and dict - elif property_type == "object": - if property_val.get("anyOf"): - sub_schemea = property_val["anyOf"][0] - sub_schemea_name = sub_schemea["title"] - attribute_list.append( - (property_key, convert_mashumaro_json_schema_to_python_class(sub_schemea, sub_schemea_name)) - ) - elif property_val.get("additionalProperties"): - attribute_list.append( - (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore - ) - else: - sub_schemea_name = property_val["title"] - attribute_list.append( - (property_key, convert_mashumaro_json_schema_to_python_class(property_val, sub_schemea_name)) - ) - elif property_type == "enum": - attribute_list.append([property_key, str]) # type: ignore - # Handle int, float, bool or str - else: - attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore - return attribute_list - - def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typing.Any): attribute_list = [] for property_key, property_val in schema[schema_name]["properties"].items(): diff --git a/plugins/flytekit-airflow/dev-requirements.in b/plugins/flytekit-airflow/dev-requirements.in new file mode 100644 index 0000000000..8ee20b47d1 --- /dev/null +++ b/plugins/flytekit-airflow/dev-requirements.in @@ -0,0 +1 @@ +apache-airflow-providers-apache-beam[google] diff --git a/plugins/flytekit-airflow/dev-requirements.txt b/plugins/flytekit-airflow/dev-requirements.txt new file mode 100644 index 0000000000..a0433fc2cb --- /dev/null +++ b/plugins/flytekit-airflow/dev-requirements.txt @@ -0,0 +1,980 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile dev-requirements.in +# +aiofiles==23.2.1 + # via gcloud-aio-storage +aiohttp==3.9.0 + # via + # apache-airflow-providers-http + # 
gcloud-aio-auth + # gcsfs +aiosignal==1.3.1 + # via aiohttp +alembic==1.12.1 + # via + # apache-airflow + # sqlalchemy-spanner +annotated-types==0.6.0 + # via pydantic +anyio==4.0.0 + # via httpx +apache-airflow==2.7.3 + # via + # apache-airflow-providers-apache-beam + # apache-airflow-providers-common-sql + # apache-airflow-providers-ftp + # apache-airflow-providers-google + # apache-airflow-providers-http + # apache-airflow-providers-imap + # apache-airflow-providers-sqlite +apache-airflow-providers-apache-beam[google]==5.3.0 + # via -r dev-requirements.in +apache-airflow-providers-common-sql==1.8.0 + # via + # apache-airflow + # apache-airflow-providers-google + # apache-airflow-providers-sqlite +apache-airflow-providers-ftp==3.6.0 + # via apache-airflow +apache-airflow-providers-google==10.11.0 + # via apache-airflow-providers-apache-beam +apache-airflow-providers-http==4.6.0 + # via apache-airflow +apache-airflow-providers-imap==3.4.0 + # via apache-airflow +apache-airflow-providers-sqlite==3.5.0 + # via apache-airflow +apache-beam[gcp]==2.51.0 + # via apache-airflow-providers-apache-beam +apispec[yaml]==6.3.0 + # via + # apispec + # flask-appbuilder +argcomplete==3.1.4 + # via apache-airflow +asgiref==3.7.2 + # via + # apache-airflow + # apache-airflow-providers-google + # apache-airflow-providers-http +async-timeout==4.0.3 + # via aiohttp +attrs==23.1.0 + # via + # aiohttp + # apache-airflow + # cattrs + # jsonschema + # looker-sdk + # referencing +babel==2.13.1 + # via flask-babel +backoff==2.2.1 + # via + # gcloud-aio-auth + # opentelemetry-exporter-otlp-proto-common + # opentelemetry-exporter-otlp-proto-grpc + # opentelemetry-exporter-otlp-proto-http +blinker==1.7.0 + # via apache-airflow +cachelib==0.9.0 + # via + # flask-caching + # flask-session +cachetools==5.3.2 + # via + # apache-beam + # google-auth +cattrs==23.1.2 + # via + # apache-airflow + # looker-sdk +certifi==2023.7.22 + # via + # httpcore + # httpx + # requests +cffi==1.16.0 + # via cryptography +chardet==5.2.0 + # via gcloud-aio-auth +charset-normalizer==3.3.2 + # via requests +click==8.1.7 + # via + # clickclick + # flask + # flask-appbuilder +clickclick==20.10.2 + # via connexion +cloudpickle==2.2.1 + # via apache-beam +colorama==0.4.6 + # via flask-appbuilder +colorlog==4.8.0 + # via apache-airflow +configupdater==3.1.1 + # via apache-airflow +connexion[flask]==2.14.2 + # via + # apache-airflow + # connexion +crcmod==1.7 + # via apache-beam +cron-descriptor==1.4.0 + # via apache-airflow +croniter==2.0.1 + # via apache-airflow +cryptography==41.0.6 + # via + # apache-airflow + # gcloud-aio-auth + # pyopenssl +db-dtypes==1.1.1 + # via pandas-gbq +decorator==5.1.1 + # via gcsfs +deprecated==1.2.14 + # via + # apache-airflow + # limits + # opentelemetry-api + # opentelemetry-exporter-otlp-proto-grpc + # opentelemetry-exporter-otlp-proto-http +dill==0.3.1.1 + # via + # apache-airflow + # apache-beam +dnspython==2.4.2 + # via + # email-validator + # pymongo +docopt==0.6.2 + # via hdfs +docutils==0.20.1 + # via python-daemon +email-validator==1.3.1 + # via flask-appbuilder +exceptiongroup==1.1.3 + # via + # anyio + # cattrs +fastavro==1.9.0 + # via apache-beam +fasteners==0.19 + # via + # apache-beam + # google-apitools +flask==2.2.5 + # via + # apache-airflow + # connexion + # flask-appbuilder + # flask-babel + # flask-caching + # flask-jwt-extended + # flask-limiter + # flask-login + # flask-session + # flask-sqlalchemy + # flask-wtf +flask-appbuilder==4.3.6 + # via apache-airflow +flask-babel==2.0.0 + # via 
flask-appbuilder +flask-caching==2.1.0 + # via apache-airflow +flask-jwt-extended==4.5.3 + # via flask-appbuilder +flask-limiter==3.5.0 + # via flask-appbuilder +flask-login==0.6.3 + # via + # apache-airflow + # flask-appbuilder +flask-session==0.5.0 + # via apache-airflow +flask-sqlalchemy==2.5.1 + # via flask-appbuilder +flask-wtf==1.2.1 + # via + # apache-airflow + # flask-appbuilder +frozenlist==1.4.0 + # via + # aiohttp + # aiosignal +fsspec==2023.10.0 + # via gcsfs +gcloud-aio-auth==4.2.3 + # via + # apache-airflow-providers-google + # gcloud-aio-bigquery + # gcloud-aio-storage +gcloud-aio-bigquery==7.0.0 + # via apache-airflow-providers-google +gcloud-aio-storage==9.0.0 + # via apache-airflow-providers-google +gcsfs==2023.10.0 + # via apache-airflow-providers-google +google-ads==22.1.0 + # via apache-airflow-providers-google +google-api-core[grpc]==2.13.0 + # via + # apache-airflow-providers-google + # apache-beam + # google-ads + # google-api-python-client + # google-cloud-aiplatform + # google-cloud-appengine-logging + # google-cloud-automl + # google-cloud-batch + # google-cloud-bigquery + # google-cloud-bigquery-datatransfer + # google-cloud-bigquery-storage + # google-cloud-bigtable + # google-cloud-build + # google-cloud-compute + # google-cloud-container + # google-cloud-core + # google-cloud-datacatalog + # google-cloud-dataflow-client + # google-cloud-dataform + # google-cloud-dataplex + # google-cloud-dataproc + # google-cloud-dataproc-metastore + # google-cloud-datastore + # google-cloud-dlp + # google-cloud-kms + # google-cloud-language + # google-cloud-logging + # google-cloud-memcache + # google-cloud-monitoring + # google-cloud-orchestration-airflow + # google-cloud-os-login + # google-cloud-pubsub + # google-cloud-pubsublite + # google-cloud-recommendations-ai + # google-cloud-redis + # google-cloud-resource-manager + # google-cloud-run + # google-cloud-secret-manager + # google-cloud-spanner + # google-cloud-speech + # google-cloud-storage + # google-cloud-storage-transfer + # google-cloud-tasks + # google-cloud-texttospeech + # google-cloud-translate + # google-cloud-videointelligence + # google-cloud-vision + # google-cloud-workflows + # pandas-gbq + # sqlalchemy-bigquery +google-api-python-client==2.107.0 + # via apache-airflow-providers-google +google-apitools==0.5.31 + # via apache-beam +google-auth==2.23.4 + # via + # apache-airflow-providers-google + # apache-beam + # gcsfs + # google-api-core + # google-api-python-client + # google-auth-httplib2 + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # pandas-gbq + # pydata-google-auth + # sqlalchemy-bigquery +google-auth-httplib2==0.1.1 + # via + # apache-airflow-providers-google + # apache-beam + # google-api-python-client +google-auth-oauthlib==1.1.0 + # via + # gcsfs + # google-ads + # pandas-gbq + # pydata-google-auth +google-cloud-aiplatform==1.36.1 + # via + # apache-airflow-providers-google + # apache-beam +google-cloud-appengine-logging==1.3.2 + # via google-cloud-logging +google-cloud-audit-log==0.2.5 + # via google-cloud-logging +google-cloud-automl==2.11.3 + # via apache-airflow-providers-google +google-cloud-batch==0.17.3 + # via apache-airflow-providers-google +google-cloud-bigquery==3.13.0 + # via + # apache-beam + # google-cloud-aiplatform + # pandas-gbq + # sqlalchemy-bigquery +google-cloud-bigquery-datatransfer==3.12.1 + # via apache-airflow-providers-google +google-cloud-bigquery-storage==2.22.0 + # via + # apache-beam + # pandas-gbq +google-cloud-bigtable==2.21.0 + # via 
+ # apache-airflow-providers-google + # apache-beam +google-cloud-build==3.21.0 + # via apache-airflow-providers-google +google-cloud-compute==1.14.1 + # via apache-airflow-providers-google +google-cloud-container==2.33.0 + # via apache-airflow-providers-google +google-cloud-core==2.3.3 + # via + # apache-beam + # google-cloud-bigquery + # google-cloud-bigtable + # google-cloud-datastore + # google-cloud-logging + # google-cloud-spanner + # google-cloud-storage + # google-cloud-translate +google-cloud-datacatalog==3.16.0 + # via apache-airflow-providers-google +google-cloud-dataflow-client==0.8.5 + # via apache-airflow-providers-google +google-cloud-dataform==0.5.4 + # via apache-airflow-providers-google +google-cloud-dataplex==1.8.1 + # via apache-airflow-providers-google +google-cloud-dataproc==5.7.0 + # via apache-airflow-providers-google +google-cloud-dataproc-metastore==1.13.0 + # via apache-airflow-providers-google +google-cloud-datastore==2.18.0 + # via apache-beam +google-cloud-dlp==3.13.0 + # via + # apache-airflow-providers-google + # apache-beam +google-cloud-kms==2.19.2 + # via apache-airflow-providers-google +google-cloud-language==2.11.1 + # via + # apache-airflow-providers-google + # apache-beam +google-cloud-logging==3.8.0 + # via apache-airflow-providers-google +google-cloud-memcache==1.7.3 + # via apache-airflow-providers-google +google-cloud-monitoring==2.16.0 + # via apache-airflow-providers-google +google-cloud-orchestration-airflow==1.9.2 + # via apache-airflow-providers-google +google-cloud-os-login==2.11.0 + # via apache-airflow-providers-google +google-cloud-pubsub==2.18.4 + # via + # apache-airflow-providers-google + # apache-beam + # google-cloud-pubsublite +google-cloud-pubsublite==1.8.3 + # via apache-beam +google-cloud-recommendations-ai==0.10.5 + # via apache-beam +google-cloud-redis==2.13.2 + # via apache-airflow-providers-google +google-cloud-resource-manager==1.10.4 + # via google-cloud-aiplatform +google-cloud-run==0.10.0 + # via apache-airflow-providers-google +google-cloud-secret-manager==2.16.4 + # via apache-airflow-providers-google +google-cloud-spanner==3.40.1 + # via + # apache-airflow-providers-google + # apache-beam + # sqlalchemy-spanner +google-cloud-speech==2.22.0 + # via apache-airflow-providers-google +google-cloud-storage==2.13.0 + # via + # apache-airflow-providers-google + # gcsfs + # google-cloud-aiplatform +google-cloud-storage-transfer==1.9.2 + # via apache-airflow-providers-google +google-cloud-tasks==2.14.2 + # via apache-airflow-providers-google +google-cloud-texttospeech==2.14.2 + # via apache-airflow-providers-google +google-cloud-translate==3.12.1 + # via apache-airflow-providers-google +google-cloud-videointelligence==2.11.4 + # via + # apache-airflow-providers-google + # apache-beam +google-cloud-vision==3.4.5 + # via + # apache-airflow-providers-google + # apache-beam +google-cloud-workflows==1.12.1 + # via apache-airflow-providers-google +google-crc32c==1.5.0 + # via + # google-cloud-storage + # google-resumable-media +google-re2==1.1 + # via apache-airflow +google-resumable-media==2.6.0 + # via + # google-cloud-bigquery + # google-cloud-storage +googleapis-common-protos[grpc]==1.61.0 + # via + # google-ads + # google-api-core + # google-cloud-audit-log + # grpc-google-iam-v1 + # grpcio-status + # opentelemetry-exporter-otlp-proto-grpc + # opentelemetry-exporter-otlp-proto-http +graphviz==0.20.1 + # via apache-airflow +greenlet==3.0.1 + # via sqlalchemy +grpc-google-iam-v1==0.12.6 + # via + # google-cloud-bigtable + # 
google-cloud-build + # google-cloud-datacatalog + # google-cloud-dataform + # google-cloud-dataplex + # google-cloud-dataproc + # google-cloud-dataproc-metastore + # google-cloud-kms + # google-cloud-logging + # google-cloud-pubsub + # google-cloud-resource-manager + # google-cloud-run + # google-cloud-secret-manager + # google-cloud-spanner + # google-cloud-tasks +grpcio==1.59.2 + # via + # apache-beam + # google-ads + # google-api-core + # google-cloud-bigquery + # google-cloud-pubsub + # google-cloud-pubsublite + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-gcp + # grpcio-status + # opentelemetry-exporter-otlp-proto-grpc +grpcio-gcp==0.2.2 + # via apache-airflow-providers-google +grpcio-status==1.59.2 + # via + # google-ads + # google-api-core + # google-cloud-pubsub + # google-cloud-pubsublite +gunicorn==21.2.0 + # via apache-airflow +h11==0.14.0 + # via httpcore +hdfs==2.7.3 + # via apache-beam +httpcore==1.0.1 + # via httpx +httplib2==0.22.0 + # via + # apache-beam + # google-api-python-client + # google-apitools + # google-auth-httplib2 + # oauth2client +httpx==0.25.1 + # via + # apache-airflow + # apache-airflow-providers-google +idna==3.4 + # via + # anyio + # email-validator + # httpx + # requests + # yarl +importlib-metadata==6.8.0 + # via opentelemetry-api +importlib-resources==6.1.1 + # via limits +inflection==0.5.1 + # via connexion +itsdangerous==2.1.2 + # via + # apache-airflow + # connexion + # flask + # flask-wtf +jinja2==3.1.2 + # via + # apache-airflow + # flask + # flask-babel + # python-nvd3 +js2py==0.74 + # via apache-beam +json-merge-patch==0.2 + # via apache-airflow-providers-google +jsonschema==4.19.2 + # via + # apache-airflow + # connexion + # flask-appbuilder +jsonschema-specifications==2023.7.1 + # via jsonschema +lazy-object-proxy==1.9.0 + # via apache-airflow +limits==3.6.0 + # via flask-limiter +linkify-it-py==2.0.2 + # via apache-airflow +lockfile==0.12.2 + # via + # apache-airflow + # python-daemon +looker-sdk==23.16.0 + # via apache-airflow-providers-google +mako==1.3.0 + # via alembic +markdown==3.5.1 + # via apache-airflow +markdown-it-py==3.0.0 + # via + # apache-airflow + # mdit-py-plugins + # rich +markupsafe==2.1.3 + # via + # apache-airflow + # jinja2 + # mako + # werkzeug + # wtforms +marshmallow==3.20.1 + # via + # flask-appbuilder + # marshmallow-oneofschema + # marshmallow-sqlalchemy +marshmallow-oneofschema==3.0.1 + # via apache-airflow +marshmallow-sqlalchemy==0.26.1 + # via flask-appbuilder +mdit-py-plugins==0.4.0 + # via apache-airflow +mdurl==0.1.2 + # via markdown-it-py +multidict==6.0.4 + # via + # aiohttp + # yarl +numpy==1.24.4 + # via + # apache-beam + # db-dtypes + # pandas + # pandas-gbq + # pyarrow + # shapely +oauth2client==4.1.3 + # via google-apitools +oauthlib==3.2.2 + # via requests-oauthlib +objsize==0.6.1 + # via apache-beam +opentelemetry-api==1.21.0 + # via + # apache-airflow + # opentelemetry-exporter-otlp-proto-grpc + # opentelemetry-exporter-otlp-proto-http + # opentelemetry-sdk +opentelemetry-exporter-otlp==1.21.0 + # via apache-airflow +opentelemetry-exporter-otlp-proto-common==1.21.0 + # via + # opentelemetry-exporter-otlp-proto-grpc + # opentelemetry-exporter-otlp-proto-http +opentelemetry-exporter-otlp-proto-grpc==1.21.0 + # via opentelemetry-exporter-otlp +opentelemetry-exporter-otlp-proto-http==1.21.0 + # via opentelemetry-exporter-otlp +opentelemetry-proto==1.21.0 + # via + # opentelemetry-exporter-otlp-proto-common + # opentelemetry-exporter-otlp-proto-grpc + # 
opentelemetry-exporter-otlp-proto-http +opentelemetry-sdk==1.21.0 + # via + # opentelemetry-exporter-otlp-proto-grpc + # opentelemetry-exporter-otlp-proto-http +opentelemetry-semantic-conventions==0.42b0 + # via opentelemetry-sdk +ordered-set==4.1.0 + # via flask-limiter +orjson==3.9.10 + # via apache-beam +overrides==6.5.0 + # via google-cloud-pubsublite +packaging==23.2 + # via + # apache-airflow + # apache-beam + # apispec + # connexion + # db-dtypes + # google-cloud-aiplatform + # google-cloud-bigquery + # gunicorn + # limits + # marshmallow + # sqlalchemy-bigquery +pandas==2.0.3 + # via + # apache-airflow-providers-google + # db-dtypes + # pandas-gbq +pandas-gbq==0.19.2 + # via apache-airflow-providers-google +pathspec==0.11.2 + # via apache-airflow +pendulum==2.1.2 + # via apache-airflow +pluggy==1.3.0 + # via apache-airflow +prison==0.2.1 + # via flask-appbuilder +proto-plus==1.22.3 + # via + # apache-airflow-providers-google + # apache-beam + # google-ads + # google-cloud-aiplatform + # google-cloud-appengine-logging + # google-cloud-automl + # google-cloud-batch + # google-cloud-bigquery + # google-cloud-bigquery-datatransfer + # google-cloud-bigquery-storage + # google-cloud-bigtable + # google-cloud-build + # google-cloud-compute + # google-cloud-container + # google-cloud-datacatalog + # google-cloud-dataflow-client + # google-cloud-dataform + # google-cloud-dataplex + # google-cloud-dataproc + # google-cloud-dataproc-metastore + # google-cloud-datastore + # google-cloud-dlp + # google-cloud-kms + # google-cloud-language + # google-cloud-logging + # google-cloud-memcache + # google-cloud-monitoring + # google-cloud-orchestration-airflow + # google-cloud-os-login + # google-cloud-pubsub + # google-cloud-recommendations-ai + # google-cloud-redis + # google-cloud-resource-manager + # google-cloud-run + # google-cloud-secret-manager + # google-cloud-spanner + # google-cloud-speech + # google-cloud-storage-transfer + # google-cloud-tasks + # google-cloud-texttospeech + # google-cloud-translate + # google-cloud-videointelligence + # google-cloud-vision + # google-cloud-workflows +protobuf==4.24.4 + # via + # apache-beam + # google-ads + # google-api-core + # google-cloud-aiplatform + # google-cloud-appengine-logging + # google-cloud-audit-log + # google-cloud-automl + # google-cloud-batch + # google-cloud-bigquery + # google-cloud-bigquery-datatransfer + # google-cloud-bigquery-storage + # google-cloud-bigtable + # google-cloud-build + # google-cloud-compute + # google-cloud-container + # google-cloud-datacatalog + # google-cloud-dataflow-client + # google-cloud-dataform + # google-cloud-dataplex + # google-cloud-dataproc + # google-cloud-dataproc-metastore + # google-cloud-datastore + # google-cloud-dlp + # google-cloud-kms + # google-cloud-language + # google-cloud-logging + # google-cloud-memcache + # google-cloud-monitoring + # google-cloud-orchestration-airflow + # google-cloud-os-login + # google-cloud-pubsub + # google-cloud-recommendations-ai + # google-cloud-redis + # google-cloud-resource-manager + # google-cloud-run + # google-cloud-secret-manager + # google-cloud-spanner + # google-cloud-speech + # google-cloud-storage-transfer + # google-cloud-tasks + # google-cloud-texttospeech + # google-cloud-translate + # google-cloud-videointelligence + # google-cloud-vision + # google-cloud-workflows + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status + # opentelemetry-proto + # proto-plus +psutil==5.9.6 + # via apache-airflow +pyarrow==11.0.0 + # via + # 
apache-beam + # db-dtypes + # pandas-gbq +pyasn1==0.5.0 + # via + # oauth2client + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via + # gcloud-aio-storage + # google-auth + # oauth2client +pycparser==2.21 + # via cffi +pydantic==2.4.2 + # via apache-airflow +pydantic-core==2.10.1 + # via pydantic +pydata-google-auth==1.8.2 + # via pandas-gbq +pydot==1.4.2 + # via apache-beam +pygments==2.16.1 + # via + # apache-airflow + # rich +pyjsparser==2.7.1 + # via js2py +pyjwt==2.8.0 + # via + # apache-airflow + # flask-appbuilder + # flask-jwt-extended + # gcloud-aio-auth +pymongo==4.6.0 + # via apache-beam +pyopenssl==23.3.0 + # via apache-airflow-providers-google +pyparsing==3.1.1 + # via + # httplib2 + # pydot +python-daemon==3.0.1 + # via apache-airflow +python-dateutil==2.8.2 + # via + # apache-airflow + # apache-beam + # croniter + # flask-appbuilder + # google-cloud-bigquery + # pandas + # pendulum +python-nvd3==0.15.0 + # via apache-airflow +python-slugify==8.0.1 + # via + # apache-airflow + # python-nvd3 +pytz==2023.3.post1 + # via + # apache-beam + # croniter + # flask-babel + # pandas +pytzdata==2020.1 + # via pendulum +pyyaml==6.0.1 + # via + # apispec + # clickclick + # connexion + # google-ads +referencing==0.30.2 + # via + # jsonschema + # jsonschema-specifications +regex==2023.10.3 + # via apache-beam +requests==2.31.0 + # via + # apache-airflow-providers-http + # apache-beam + # connexion + # gcsfs + # google-api-core + # google-cloud-bigquery + # google-cloud-storage + # hdfs + # looker-sdk + # opentelemetry-exporter-otlp-proto-http + # requests-oauthlib + # requests-toolbelt +requests-oauthlib==1.3.1 + # via google-auth-oauthlib +requests-toolbelt==1.0.0 + # via apache-airflow-providers-http +rfc3339-validator==0.1.4 + # via apache-airflow +rich==13.6.0 + # via + # apache-airflow + # flask-limiter + # rich-argparse +rich-argparse==1.4.0 + # via apache-airflow +rpds-py==0.12.0 + # via + # jsonschema + # referencing +rsa==4.9 + # via + # gcloud-aio-storage + # google-auth + # oauth2client +setproctitle==1.3.3 + # via apache-airflow +shapely==2.0.2 + # via google-cloud-aiplatform +six==1.16.0 + # via + # google-apitools + # hdfs + # js2py + # oauth2client + # prison + # python-dateutil + # rfc3339-validator +sniffio==1.3.0 + # via + # anyio + # httpx +sqlalchemy==1.4.50 + # via + # alembic + # apache-airflow + # flask-appbuilder + # flask-sqlalchemy + # marshmallow-sqlalchemy + # sqlalchemy-bigquery + # sqlalchemy-jsonfield + # sqlalchemy-spanner + # sqlalchemy-utils +sqlalchemy-bigquery==1.8.0 + # via apache-airflow-providers-google +sqlalchemy-jsonfield==1.0.1.post0 + # via apache-airflow +sqlalchemy-spanner==1.6.2 + # via apache-airflow-providers-google +sqlalchemy-utils==0.41.1 + # via flask-appbuilder +sqlparse==0.4.4 + # via + # apache-airflow-providers-common-sql + # google-cloud-spanner +tabulate==0.9.0 + # via apache-airflow +tenacity==8.2.3 + # via apache-airflow +termcolor==2.3.0 + # via apache-airflow +text-unidecode==1.3 + # via python-slugify +typing-extensions==4.8.0 + # via + # alembic + # apache-airflow + # apache-beam + # asgiref + # cattrs + # flask-limiter + # limits + # looker-sdk + # opentelemetry-sdk + # pydantic + # pydantic-core +tzdata==2023.3 + # via pandas +tzlocal==5.2 + # via js2py +uc-micro-py==1.0.2 + # via linkify-it-py +unicodecsv==0.14.1 + # via apache-airflow +uritemplate==4.1.1 + # via google-api-python-client +urllib3==2.0.7 + # via requests +werkzeug==2.2.3 + # via + # apache-airflow + # connexion + # flask + # flask-jwt-extended + # 
flask-login +wrapt==1.15.0 + # via deprecated +wtforms==3.0.1 + # via + # apache-airflow + # flask-appbuilder + # flask-wtf +yarl==1.9.2 + # via aiohttp +zipp==3.17.0 + # via importlib-metadata +zstandard==0.22.0 + # via apache-beam + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-airflow/flytekitplugins/airflow/__init__.py b/plugins/flytekit-airflow/flytekitplugins/airflow/__init__.py index 1015066db1..d09cb952cf 100644 --- a/plugins/flytekit-airflow/flytekitplugins/airflow/__init__.py +++ b/plugins/flytekit-airflow/flytekitplugins/airflow/__init__.py @@ -13,4 +13,4 @@ """ from .agent import AirflowAgent -from .task import AirflowConfig, AirflowTask +from .task import AirflowObj, AirflowTask diff --git a/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py b/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py index acd7f9d245..22da03bbc8 100644 --- a/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py +++ b/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py @@ -1,19 +1,18 @@ -import importlib -from dataclasses import dataclass +import asyncio +import typing +from dataclasses import dataclass, field from typing import Optional import cloudpickle import grpc import jsonpickle -from airflow.providers.google.cloud.operators.dataproc import ( - DataprocDeleteClusterOperator, - DataprocJobBaseOperator, - JobStatus, -) +from airflow.exceptions import AirflowException, TaskDeferred +from airflow.models import BaseOperator from airflow.sensors.base import BaseSensorOperator +from airflow.triggers.base import TriggerEvent from airflow.utils.context import Context from flyteidl.admin.agent_pb2 import ( - PERMANENT_FAILURE, + RETRYABLE_FAILURE, RUNNING, SUCCEEDED, CreateTaskResponse, @@ -21,10 +20,10 @@ GetTaskResponse, Resource, ) -from flytekitplugins.airflow.task import AirflowConfig -from google.cloud.exceptions import NotFound +from flytekitplugins.airflow.task import AirflowObj, _get_airflow_instance -from flytekit import FlyteContext, FlyteContextManager, logger +from flytekit import logger +from flytekit.exceptions.user import FlyteUserException from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -32,77 +31,115 @@ @dataclass class ResourceMetadata: - job_id: str - airflow_config: AirflowConfig - - -def _get_airflow_task(ctx: FlyteContext, airflow_config: AirflowConfig): - task_module = importlib.import_module(name=airflow_config.task_module) - task_def = getattr(task_module, airflow_config.task_name) - task_config = airflow_config.task_config + """ + This class is used to store the Airflow task configuration. It is serialized and returned to FlytePropeller. + """ - # Set the GET_ORIGINAL_TASK attribute to True so that task_def will return the original - # airflow task instead of the Flyte task. - ctx.user_space_params.builder().add_attr("GET_ORIGINAL_TASK", True).build() - if issubclass(task_def, DataprocJobBaseOperator): - return task_def(**task_config, asynchronous=True) - return task_def(**task_config) + airflow_operator: AirflowObj + airflow_trigger: AirflowObj = field(default=None) + airflow_trigger_callback: str = field(default=None) + job_id: typing.Optional[str] = field(default=None) class AirflowAgent(AgentBase): + """ + It is used to run Airflow tasks. It is registered as an agent in the AgentRegistry. 
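+    In short (summarizing the methods below): async_create instantiates the wrapped Airflow object and,
+    for deferrable operators, captures the trigger raised via TaskDeferred; async_get polls that trigger
+    (or the sensor/hook) and reports RUNNING, SUCCEEDED or RETRYABLE_FAILURE back to FlytePropeller;
+    async_delete is currently a no-op.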
+ There are three kinds of Airflow tasks: AirflowOperator, AirflowSensor, and AirflowHook. + + Sensor is always invoked in get method. Calling get method to check if the certain condition is met. + For example, FileSensor is used to check if the file exists. If file doesn't exist, agent returns + RUNNING status, otherwise, it returns SUCCEEDED status. + + Hook is a high-level interface to an external platform that lets you quickly and easily talk to + them without having to write low-level code that hits their API or uses special libraries. For example, + SlackHook is used to send messages to Slack. Therefore, Hooks are also invoked in get method. + Note: There is no running state for Hook. It is either successful or failed. + + Operator is invoked in create method. Flytekit will always set deferrable to True for Operator. Therefore, + `operator.execute` will always raise TaskDeferred exception after job is submitted. In the get method, + we create a trigger to check if the job is finished. + Note: some of the operators are not deferrable. For example, BeamRunJavaPipelineOperator, BeamRunPythonPipelineOperator. + In this case, those operators will be converted to AirflowContainerTask and executed in the pod. + """ + def __init__(self): - super().__init__(task_type="airflow", asynchronous=False) + super().__init__(task_type="airflow", asynchronous=True) - def create( + async def async_create( self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, ) -> CreateTaskResponse: - airflow_config = jsonpickle.decode(task_template.custom.get("task_config_pkl")) - resource_meta = ResourceMetadata(job_id="", airflow_config=airflow_config) + airflow_obj = jsonpickle.decode(task_template.custom["task_config_pkl"]) + airflow_instance = _get_airflow_instance(airflow_obj) + resource_meta = ResourceMetadata(airflow_operator=airflow_obj) - ctx = FlyteContextManager.current_context() - airflow_task = _get_airflow_task(ctx, airflow_config) - if isinstance(airflow_task, DataprocJobBaseOperator): - airflow_task.execute(context=Context()) - resource_meta.job_id = ctx.user_space_params.xcom_data["value"]["resource"] + if isinstance(airflow_instance, BaseOperator) and not isinstance(airflow_instance, BaseSensorOperator): + try: + resource_meta = ResourceMetadata(airflow_operator=airflow_obj) + airflow_instance.execute(context=Context()) + except TaskDeferred as td: + parameters = td.trigger.__dict__.copy() + # Remove parameters that are in the base class + parameters.pop("task_instance", None) + parameters.pop("trigger_id", None) + + resource_meta.airflow_trigger = AirflowObj( + module=td.trigger.__module__, name=td.trigger.__class__.__name__, parameters=parameters + ) + resource_meta.airflow_trigger_callback = td.method_name return CreateTaskResponse(resource_meta=cloudpickle.dumps(resource_meta)) - def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: meta = cloudpickle.loads(resource_meta) - airflow_config = meta.airflow_config - job_id = meta.job_id - task = _get_airflow_task(FlyteContextManager.current_context(), meta.airflow_config) + airflow_operator_instance = _get_airflow_instance(meta.airflow_operator) + airflow_trigger_instance = _get_airflow_instance(meta.airflow_trigger) if meta.airflow_trigger else None + airflow_ctx = Context() + message = None cur_state = RUNNING - if issubclass(type(task), 
BaseSensorOperator): - if task.poke(context=Context()): + if isinstance(airflow_operator_instance, BaseSensorOperator): + ok = airflow_operator_instance.poke(context=airflow_ctx) + cur_state = SUCCEEDED if ok else RUNNING + elif isinstance(airflow_operator_instance, BaseOperator): + if airflow_trigger_instance: + try: + # Airflow trigger returns immediately when + # 1. Failed to get the task status + # 2. Task succeeded or failed + # succeeded or failed: returns a TriggerEvent with payload + # running: runs forever, so set a default timeout (2 seconds) here. + # failed to get the status: raises AirflowException + event = await asyncio.wait_for(airflow_trigger_instance.run().__anext__(), 2) + try: + # Trigger callback will check the status of the task in the payload, and raise AirflowException if failed. + trigger_callback = getattr(airflow_operator_instance, meta.airflow_trigger_callback) + trigger_callback(context=airflow_ctx, event=typing.cast(TriggerEvent, event).payload) + cur_state = SUCCEEDED + except AirflowException as e: + cur_state = RETRYABLE_FAILURE + message = e.__str__() + except asyncio.TimeoutError: + logger.debug("No event received from airflow trigger") + except AirflowException as e: + cur_state = RETRYABLE_FAILURE + message = e.__str__() + else: + # If there is no trigger, it means the operator is not deferrable. In this case, this operator will be + # executed in the creation step. Therefore, we can directly return SUCCEEDED here. + # For instance, SlackWebhookOperator is not deferrable. It sends a message to Slack in the creation step. + # If the message is sent successfully, agent will return SUCCEEDED here. Otherwise, it will raise an exception at creation step. cur_state = SUCCEEDED - elif issubclass(type(task), DataprocJobBaseOperator): - job = task.hook.get_job( - job_id=job_id, - region=airflow_config.task_config["region"], - project_id=airflow_config.task_config["project_id"], - ) - if job.status.state == JobStatus.State.DONE: - cur_state = SUCCEEDED - elif job.status.state in (JobStatus.State.ERROR, JobStatus.State.CANCELLED): - cur_state = PERMANENT_FAILURE - elif isinstance(task, DataprocDeleteClusterOperator): - try: - task.execute(context=Context()) - except NotFound: - logger.info("Cluster already deleted.") - cur_state = SUCCEEDED + else: - task.execute(context=Context()) - cur_state = SUCCEEDED - return GetTaskResponse(resource=Resource(state=cur_state, outputs=None)) + raise FlyteUserException("Only sensor and operator are supported.") + + return GetTaskResponse(resource=Resource(state=cur_state, message=message)) - def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: return DeleteTaskResponse() diff --git a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py index a25a46cf1e..a5c10346a6 100644 --- a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py +++ b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py @@ -1,3 +1,4 @@ +import importlib import logging import typing from dataclasses import dataclass @@ -5,47 +6,167 @@ import jsonpickle from airflow import DAG +from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.sensors.base import BaseSensorOperator +from airflow.triggers.base import BaseTrigger +from airflow.utils.context import Context -from flytekit import FlyteContextManager 
+from flytekit import FlyteContextManager, logger from flytekit.configuration import SerializationSettings -from flytekit.core.base_task import PythonTask +from flytekit.core.base_task import PythonTask, TaskResolverMixin from flytekit.core.interface import Interface +from flytekit.core.python_auto_container import PythonAutoContainerTask +from flytekit.core.tracker import TrackedInstance +from flytekit.core.utils import timeit from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin @dataclass -class AirflowConfig(object): - task_module: str - task_name: str - task_config: typing.Dict[str, Any] +class AirflowObj(object): + """ + This class is used to store the Airflow task configuration. It is serialized and stored in the Flyte task config. + It can be trigger, hook, operator or sensor. For example: + + from airflow.sensors.filesystem import FileSensor + sensor = FileSensor(task_id="id", filepath="/tmp/1234") + + In this case, the attributes of AirflowObj will be: + module: airflow.sensors.filesystem + name: FileSensor + parameters: {"task_id": "id", "filepath": "/tmp/1234"} + """ + + module: str + name: str + parameters: typing.Dict[str, Any] + + +class AirflowTaskResolver(TrackedInstance, TaskResolverMixin): + """ + This class is used to resolve an Airflow task. It will load an airflow task in the container. + """ + + def name(self) -> str: + return "AirflowTaskResolver" + + @timeit("Load airflow task") + def load_task(self, loader_args: typing.List[str]) -> typing.Union[BaseOperator, BaseSensorOperator, BaseTrigger]: + """ + This method is used to load an Airflow task. + """ + _, task_module, _, task_name, _, task_config = loader_args + task_module = importlib.import_module(name=task_module) # type: ignore + task_def = getattr(task_module, task_name) + return task_def(name=task_name, task_config=jsonpickle.decode(task_config)) + + def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> typing.List[str]: + return [ + "task-module", + task.__module__, + "task-name", + task.__class__.__name__, + "task-config", + jsonpickle.encode(task.task_config), + ] + + def get_all_tasks(self) -> typing.List[PythonAutoContainerTask]: # type: ignore + raise Exception("should not be needed") -class AirflowTask(AsyncAgentExecutorMixin, PythonTask[AirflowConfig]): +airflow_task_resolver = AirflowTaskResolver() + + +class AirflowContainerTask(PythonAutoContainerTask[AirflowObj]): + """ + This python container task is used to wrap an Airflow task. It is used to run an Airflow task in a container. + The airflow task module, name and parameters are stored in the task config. + + Some of the Airflow operators are not deferrable, For example, BeamRunJavaPipelineOperator, BeamRunPythonPipelineOperator. + These tasks don't have async method to get the job status, so cannot be used in the Flyte agent. We run these tasks in a container. + """ + + def __init__( + self, + name: str, + task_config: AirflowObj, + inputs: Optional[Dict[str, Type]] = None, + **kwargs, + ): + super().__init__( + name=name, + task_config=task_config, + interface=Interface(inputs=inputs or {}), + **kwargs, + ) + self._task_resolver = airflow_task_resolver + + def execute(self, **kwargs) -> Any: + logger.info("Executing Airflow task") + _get_airflow_instance(self.task_config).execute(context=Context()) + + +class AirflowTask(AsyncAgentExecutorMixin, PythonTask[AirflowObj]): + """ + This python task is used to wrap an Airflow task. It is used to run an Airflow task in Flyte agent. 
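+    Users normally do not construct this class directly: instantiating an Airflow operator or sensor
+    with a `task_id` inside a Flyte workflow is intercepted by the monkey-patched `BaseOperator.__new__`
+    below and wrapped into an AirflowTask automatically. An illustrative sketch (FileSensor is only an
+    example; most non-Beam operators and sensors are routed the same way):
+
+        from airflow.sensors.filesystem import FileSensor
+        from flytekit import workflow
+
+        @workflow
+        def wf():
+            FileSensor(task_id="sensor", filepath="/tmp/1234")
+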
+ The airflow task module, name and parameters are stored in the task config. We run the Airflow task in the agent. + """ + _TASK_TYPE = "airflow" def __init__( self, name: str, - query_template: str, - task_config: Optional[AirflowConfig], + task_config: Optional[AirflowObj], inputs: Optional[Dict[str, Type]] = None, **kwargs, ): super().__init__( name=name, task_config=task_config, - query_template=query_template, interface=Interface(inputs=inputs or {}), task_type=self._TASK_TYPE, **kwargs, ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + # Use jsonpickle to serialize the Airflow task config since the return value should be json serializable. return {"task_config_pkl": jsonpickle.encode(self.task_config)} +def _get_airflow_instance(airflow_obj: AirflowObj) -> typing.Union[BaseOperator, BaseSensorOperator, BaseTrigger]: + # Set the GET_ORIGINAL_TASK attribute to True so that obj_def will return the original + # airflow task instead of the Flyte task. + ctx = FlyteContextManager.current_context() + ctx.user_space_params.builder().add_attr("GET_ORIGINAL_TASK", True).build() + + obj_module = importlib.import_module(name=airflow_obj.module) + obj_def = getattr(obj_module, airflow_obj.name) + if issubclass(obj_def, BaseOperator) and not issubclass(obj_def, BaseSensorOperator) and _is_deferrable(obj_def): + try: + return obj_def(**airflow_obj.parameters, deferrable=True) + except AirflowException as e: + logger.debug(f"Failed to create operator {airflow_obj.name} with err: {e}.") + logger.debug(f"Airflow operator {airflow_obj.name} does not support deferring.") + + return obj_def(**airflow_obj.parameters) + + +def _is_deferrable(cls: Type): + """ + This function is used to check if the Airflow operator is deferrable. + """ + try: + from airflow.providers.apache.beam.operators.beam import BeamBasePipelineOperator + + # Dataflow operators are not deferrable. + if not issubclass(cls, BeamBasePipelineOperator): + return False + except ImportError: + logger.debug("Failed to import BeamBasePipelineOperator") + return True + + def _flyte_operator(*args, **kwargs): """ This function is called by the Airflow operator to create a new task. We intercept this call and return a Flyte @@ -57,10 +178,18 @@ def _flyte_operator(*args, **kwargs): # Return original task when running in the agent. return object.__new__(cls) except AssertionError: + # This happens when the task is created in the dynamic workflow. + # We don't need to return the original task in this case. logging.debug("failed to get the attribute GET_ORIGINAL_TASK from user space params") - config = AirflowConfig(task_module=cls.__module__, task_name=cls.__name__, task_config=kwargs) - t = AirflowTask(name=kwargs["task_id"], query_template="", task_config=config, original_new=cls.__new__) - return t() + + container_image = kwargs.pop("container_image", None) + task_id = kwargs["task_id"] or cls.__name__ + config = AirflowObj(module=cls.__module__, name=cls.__name__, parameters=kwargs) + + if _is_deferrable(cls): + # Dataflow operators are not deferrable, so we run them in a container. + return AirflowContainerTask(name=task_id, task_config=config, container_image=container_image)() + return AirflowTask(name=task_id, task_config=config)() def _flyte_xcom_push(*args, **kwargs): @@ -68,12 +197,24 @@ def _flyte_xcom_push(*args, **kwargs): This function is called by the Airflow operator to push data to XCom. We intercept this call and store the data in the Flyte context. 
""" - FlyteContextManager.current_context().user_space_params.xcom_data = kwargs + if len(args) < 2: + return + # Store the XCom data in the Flyte context. + # args[0] is the operator instance. + # args[1:] are the XCom data. + # For example, + # op.xcom_push(Context(), "key", "value") + # args[0] is op, args[1:] is [Context(), "key", "value"] + FlyteContextManager.current_context().user_space_params.xcom_data = args[1:] params = FlyteContextManager.current_context().user_space_params params.builder().add_attr("GET_ORIGINAL_TASK", False).add_attr("XCOM_DATA", {}).build() +# Monkey patch the Airflow operator. Instead of creating an airflow task, it returns a Flyte task. BaseOperator.__new__ = _flyte_operator BaseOperator.xcom_push = _flyte_xcom_push +# Monkey patch the xcom_push method to store the data in the Flyte context. +# Create a dummy DAG to avoid Airflow errors. This DAG is not used. +# TODO: Add support using Airflow DAG in Flyte workflow. We can probably convert the Airflow DAG to a Flyte subworkflow. BaseSensorOperator.dag = DAG(dag_id="flyte_dag") diff --git a/plugins/flytekit-airflow/setup.py b/plugins/flytekit-airflow/setup.py index f077e174b1..91214e6dbf 100644 --- a/plugins/flytekit-airflow/setup.py +++ b/plugins/flytekit-airflow/setup.py @@ -6,10 +6,7 @@ plugin_requires = [ "apache-airflow", - "jsonpickle", "flytekit>=1.9.0", - "google-cloud-orchestration-airflow", - "apache-airflow-providers-google", ] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-airflow/tests/test_agent.py b/plugins/flytekit-airflow/tests/test_agent.py index 2f266e323a..7b46a5e7e6 100644 --- a/plugins/flytekit-airflow/tests/test_agent.py +++ b/plugins/flytekit-airflow/tests/test_agent.py @@ -1,10 +1,22 @@ +from unittest.mock import MagicMock from datetime import datetime, timedelta, timezone +import grpc +import jsonpickle +import pytest from airflow.operators.python import PythonOperator from airflow.sensors.bash import BashSensor from airflow.sensors.time_sensor import TimeSensor +from flyteidl.admin.agent_pb2 import SUCCEEDED, DeleteTaskResponse +from flytekitplugins.airflow import AirflowObj +from flytekitplugins.airflow.agent import AirflowAgent, ResourceMetadata from flytekit import workflow +from flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import interface as interface_models +from flytekit.models import literals, task +from flytekit.models.core.identifier import ResourceType +from flytekit.models.task import TaskTemplate def py_func(): @@ -22,5 +34,67 @@ def wf(): sensor >> t3 >> foo -def test_airflow_agent(): +def test_airflow_workflow(): wf() + + +def test_resource_metadata(): + task_cfg = AirflowObj( + module="airflow.operators.bash", + name="BashOperator", + parameters={"task_id": "id", "bash_command": "echo 'hello world'"}, + ) + trigger_cfg = AirflowObj(module="airflow.trigger.file", name="FileTrigger", parameters={"filepath": "file.txt"}) + meta = ResourceMetadata( + airflow_operator=task_cfg, + airflow_trigger=trigger_cfg, + airflow_trigger_callback="execute_complete", + job_id="123", + ) + assert meta.airflow_operator == task_cfg + assert meta.airflow_trigger == trigger_cfg + assert meta.airflow_trigger_callback == "execute_complete" + assert meta.job_id == "123" + + +@pytest.mark.asyncio +async def test_airflow_agent(): + cfg = AirflowObj( + module="airflow.operators.bash", + name="BashOperator", + parameters={"task_id": "id", "bash_command": "echo 'hello world'"}, + ) + task_id = Identifier( + resource_type=ResourceType.TASK, 
project="project", domain="domain", name="airflow_Task", version="version" + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + ) + + interfaces = interface_models.TypedInterface(inputs={}, outputs={}) + + dummy_template = TaskTemplate( + id=task_id, + metadata=task_metadata, + interface=interfaces, + type="airflow", + custom={"task_config_pkl": jsonpickle.encode(cfg)}, + ) + + agent = AirflowAgent() + grpc_ctx = MagicMock(spec=grpc.ServicerContext) + res = await agent.async_create(grpc_ctx, "/tmp", dummy_template, None) + metadata = res.resource_meta + res = await agent.async_get(grpc_ctx, metadata) + assert res.resource.state == SUCCEEDED + assert res.resource.message == "" + res = await agent.async_delete(grpc_ctx, metadata) + assert res == DeleteTaskResponse() diff --git a/plugins/flytekit-airflow/tests/test_task.py b/plugins/flytekit-airflow/tests/test_task.py new file mode 100644 index 0000000000..3d6a954b7f --- /dev/null +++ b/plugins/flytekit-airflow/tests/test_task.py @@ -0,0 +1,87 @@ +import jsonpickle +from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator +from airflow.sensors.bash import BashSensor +from airflow.utils.context import Context +from flytekitplugins.airflow.task import ( + AirflowContainerTask, + AirflowObj, + AirflowTask, + _is_deferrable, + airflow_task_resolver, +) + +from flytekit import FlyteContextManager +from flytekit.configuration import ImageConfig, SerializationSettings + + +def test_xcom_push(): + ctx = FlyteContextManager.current_context() + ctx.user_space_params._attrs = {} + + execution_state = ctx.execution_state.with_params( + user_space_params=ctx.user_space_params.new_builder() + .add_attr("GET_ORIGINAL_TASK", True) + .add_attr("XCOM_DATA", {}) + .build() + ) + + with FlyteContextManager.with_context(ctx.with_execution_state(execution_state)) as child_ctx: + print(child_ctx.user_space_params.get_original_task) + op = BashSensor(task_id="Sensor_succeeds", bash_command="exit 0") + op.xcom_push(Context(), "key", "value") + assert child_ctx.user_space_params.xcom_data[1] == "key" + assert child_ctx.user_space_params.xcom_data[2] == "value" + + +def test_is_deferrable(): + assert _is_deferrable(BeamRunJavaPipelineOperator) is True + assert _is_deferrable(BashSensor) is False + + +def test_airflow_task(): + cfg = AirflowObj( + module="airflow.operators.bash", + name="BashOperator", + parameters={"task_id": "id", "bash_command": "echo 'hello world'"}, + ) + t = AirflowTask(name="test_bash_operator", task_config=cfg) + serialization_settings = SerializationSettings( + project="proj", + domain="dom", + version="123", + image_config=ImageConfig.auto(), + env={}, + ) + t.get_custom(serialization_settings)["task_config_pkl"] = jsonpickle.encode(cfg) + t.execute() + + +def test_airflow_container_task(): + cfg = AirflowObj( + module="airflow.providers.apache.beam.operators.beam", + name="BeamRunJavaPipelineOperator", + parameters={"task_id": "id", "job_class": "org.apache.beam.examples.WordCount"}, + ) + t = AirflowContainerTask(name="test_dataflow_operator", task_config=cfg) + serialization_settings = SerializationSettings( + project="proj", + domain="dom", + version="123", + image_config=ImageConfig.auto(), + env={}, + ) + assert t.task_resolver.name() == "AirflowTaskResolver" + assert 
t.task_resolver.loader_args(serialization_settings, t) == [ + "task-module", + "flytekitplugins.airflow.task", + "task-name", + "AirflowContainerTask", + "task-config", + '{"py/object": "flytekitplugins.airflow.task.AirflowObj", "module": ' + '"airflow.providers.apache.beam.operators.beam", "name": ' + '"BeamRunJavaPipelineOperator", "parameters": {"task_id": "id", "job_class": ' + '"org.apache.beam.examples.WordCount"}}', + ] + assert isinstance( + airflow_task_resolver.load_task(t.task_resolver.loader_args(serialization_settings, t)), AirflowContainerTask + ) diff --git a/plugins/flytekit-flyin/flytekitplugins/flyin/__init__.py b/plugins/flytekit-flyin/flytekitplugins/flyin/__init__.py index 391587faa3..da2612bb4f 100644 --- a/plugins/flytekit-flyin/flytekitplugins/flyin/__init__.py +++ b/plugins/flytekit-flyin/flytekitplugins/flyin/__init__.py @@ -19,9 +19,11 @@ COPILOT_CONFIG CODE_TOGETHER_CONFIG jupyter + get_task_inputs """ from .jupyter_lib.decorator import jupyter +from .utils import get_task_inputs from .vscode_lib.config import ( CODE_TOGETHER_CONFIG, CODE_TOGETHER_EXTENSION, diff --git a/plugins/flytekit-flyin/flytekitplugins/flyin/utils.py b/plugins/flytekit-flyin/flytekitplugins/flyin/utils.py new file mode 100644 index 0000000000..7e879583ff --- /dev/null +++ b/plugins/flytekit-flyin/flytekitplugins/flyin/utils.py @@ -0,0 +1,56 @@ +import importlib +import os +import sys + +from flytekit.core import utils +from flytekit.core.context_manager import FlyteContextManager +from flyteidl.core import literals_pb2 as _literals_pb2 +from flytekit.core.type_engine import TypeEngine +from flytekit.models import literals as _literal_models + + +def load_module_from_path(module_name, path): + """ + Imports a Python module from a specified file path. + + Args: + module_name (str): The name you want to assign to the imported module. + path (str): The file system path to the Python file (.py) that contains the module you want to import. + + Returns: + module: The imported module. + + Raises: + ImportError: If the module cannot be loaded from the provided path, an ImportError is raised. + """ + spec = importlib.util.spec_from_file_location(module_name, path) + if spec is not None: + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + else: + raise ImportError(f"Module at {path} could not be loaded") + + +def get_task_inputs(task_module_name, task_name, context_working_dir): + """ + Read task input data from inputs.pb for a specific task function and convert it into Python types and structures. + + Args: + task_module_name (str): The name of the Python module containing the task function. + task_name (str): The name of the task function within the module. + context_working_dir (str): The directory path where the input file and module file are located. + + Returns: + dict: A dictionary containing the task inputs, converted into Python types and structures. 
+ """ + local_inputs_file = os.path.join(context_working_dir, "inputs.pb") + input_proto = utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file) + idl_input_literals = _literal_models.LiteralMap.from_flyte_idl(input_proto) + task_module = load_module_from_path(task_module_name, os.path.join(context_working_dir, f"{task_module_name}.py")) + task_def = getattr(task_module, task_name) + native_inputs = TypeEngine.literal_map_to_kwargs( + FlyteContextManager(), idl_input_literals, task_def.python_interface.inputs + ) + return native_inputs diff --git a/plugins/flytekit-flyin/flytekitplugins/flyin/vscode_lib/constants.py b/plugins/flytekit-flyin/flytekitplugins/flyin/vscode_lib/constants.py index f8c971eaee..cf9022dce0 100644 --- a/plugins/flytekit-flyin/flytekitplugins/flyin/vscode_lib/constants.py +++ b/plugins/flytekit-flyin/flytekitplugins/flyin/vscode_lib/constants.py @@ -23,3 +23,5 @@ # The path is hardcoded by code-server # https://coder.com/docs/code-server/latest/FAQ#what-is-the-heartbeat-file HEARTBEAT_PATH = os.path.expanduser("~/.local/share/code-server/heartbeat") + +INTERACTIVE_DEBUGGING_FILE_NAME = "flyin_interactive_entrypoint.py" diff --git a/plugins/flytekit-flyin/flytekitplugins/flyin/vscode_lib/decorator.py b/plugins/flytekit-flyin/flytekitplugins/flyin/vscode_lib/decorator.py index 53e7a52f08..9c461ecc31 100644 --- a/plugins/flytekit-flyin/flytekitplugins/flyin/vscode_lib/decorator.py +++ b/plugins/flytekit-flyin/flytekitplugins/flyin/vscode_lib/decorator.py @@ -1,3 +1,4 @@ +import json import multiprocessing import os import shutil @@ -12,9 +13,16 @@ import flytekit from flytekit.core.context_manager import FlyteContextManager - +import flytekit from .config import VscodeConfig -from .constants import DOWNLOAD_DIR, EXECUTABLE_NAME, HEARTBEAT_CHECK_SECONDS, HEARTBEAT_PATH, MAX_IDLE_SECONDS +from .constants import ( + DOWNLOAD_DIR, + EXECUTABLE_NAME, + HEARTBEAT_CHECK_SECONDS, + HEARTBEAT_PATH, + MAX_IDLE_SECONDS, + INTERACTIVE_DEBUGGING_FILE_NAME, +) def execute_command(cmd): @@ -149,6 +157,64 @@ def download_vscode(vscode_config: VscodeConfig): execute_command(f"code-server --install-extension {p}") +def prepare_interactive_python(task_function): + """ + 1. Copy the original task file to the context working directory. This ensures that the inputs.pb can be loaded, as loading requires the original task interface. + By doing so, even if users change the task interface in their code, we can use the copied task file to load the inputs as native Python objects. + 2. Generate a Python script and a launch.json for users to debug interactively. + + Args: + task_function (function): User's task function. + """ + + context_working_dir = FlyteContextManager.current_context().execution_state.working_dir + + # Copy the user's Python file to the working directory. + shutil.copy(f"{task_function.__module__}.py", os.path.join(context_working_dir, f"{task_function.__module__}.py")) + + # Generate a Python script + task_module_name, task_name = task_function.__module__, task_function.__name__ + python_script = f"""# This file is auto-generated by flyin + +from {task_module_name} import {task_name} +from flytekitplugins.flyin import get_task_inputs + +if __name__ == "__main__": + inputs = get_task_inputs( + task_module_name="{task_module_name}", + task_name="{task_name}", + context_working_dir="{context_working_dir}", + ) + # You can modify the inputs! 
Ex: inputs['a'] = 5 + print({task_name}(**inputs)) +""" + + with open(INTERACTIVE_DEBUGGING_FILE_NAME, "w") as file: + file.write(python_script) + + # Generate a launch.json + launch_json = { + "version": "0.2.0", + "configurations": [ + { + "name": "Interactive Debugging", + "type": "python", + "request": "launch", + "program": os.path.join(os.getcwd(), INTERACTIVE_DEBUGGING_FILE_NAME), + "console": "integratedTerminal", + "justMyCode": True, + } + ], + } + + vscode_directory = ".vscode" + if not os.path.exists(vscode_directory): + os.makedirs(vscode_directory) + + with open(os.path.join(vscode_directory, "launch.json"), "w") as file: + json.dump(launch_json, file, indent=4) + + def vscode( _task_function: Optional[Callable] = None, max_idle_seconds: Optional[int] = MAX_IDLE_SECONDS, @@ -163,8 +229,9 @@ def vscode( vscode decorator modifies a container to run a VSCode server: 1. Overrides the user function with a VSCode setup function. 2. Download vscode server and extension from remote to local. - 3. Launches and monitors the VSCode server. - 4. Terminates if the server is idle for a set duration. + 3. Prepare the interactive debugging Python script and launch.json. + 4. Launches and monitors the VSCode server. + 5. Terminates if the server is idle for a set duration. Args: _task_function (function, optional): The user function to be decorated. Defaults to None. @@ -186,11 +253,12 @@ def wrapper(fn): @wraps(fn) def inner_wrapper(*args, **kwargs): + ctx = FlyteContextManager.current_context() logger = flytekit.current_context().logging # When user use pyflyte run or python to execute the task, we don't launch the VSCode server. # Only when user use pyflyte run --remote to submit the task to cluster, we launch the VSCode server. - if FlyteContextManager.current_context().execution_state.is_local_execution(): + if ctx.execution_state.is_local_execution(): return fn(*args, **kwargs) if run_task_first: @@ -209,7 +277,10 @@ def inner_wrapper(*args, **kwargs): # 1. Downloads the VSCode server from Internet to local. download_vscode(config) - # 2. Launches and monitors the VSCode server. + # 2. Prepare the interactive debugging Python script and launch.json. + prepare_interactive_python(fn) + + # 3. Launches and monitors the VSCode server. 
# Run the function in the background child_process = multiprocessing.Process( target=execute_command, diff --git a/plugins/flytekit-flyin/tests/test_flyin_plugin.py b/plugins/flytekit-flyin/tests/test_flyin_plugin.py index c614a3d2aa..00a5990cd3 100644 --- a/plugins/flytekit-flyin/tests/test_flyin_plugin.py +++ b/plugins/flytekit-flyin/tests/test_flyin_plugin.py @@ -32,9 +32,12 @@ def mock_remote_execution(): @mock.patch("multiprocessing.Process") +@mock.patch("flytekitplugins.flyin.vscode_lib.decorator.prepare_interactive_python") @mock.patch("flytekitplugins.flyin.vscode_lib.decorator.exit_handler") @mock.patch("flytekitplugins.flyin.vscode_lib.decorator.download_vscode") -def test_vscode_remote_execution(mock_download_vscode, mock_exit_handler, mock_process, mock_remote_execution): +def test_vscode_remote_execution( + mock_download_vscode, mock_exit_handler, mock_process, mock_prepare_interactive_python, mock_remote_execution +): @task @vscode def t(): @@ -48,12 +51,16 @@ def wf(): mock_download_vscode.assert_called_once() mock_process.assert_called_once() mock_exit_handler.assert_called_once() + mock_prepare_interactive_python.assert_called_once() @mock.patch("multiprocessing.Process") +@mock.patch("flytekitplugins.flyin.vscode_lib.decorator.prepare_interactive_python") @mock.patch("flytekitplugins.flyin.vscode_lib.decorator.exit_handler") @mock.patch("flytekitplugins.flyin.vscode_lib.decorator.download_vscode") -def test_vscode_local_execution(mock_download_vscode, mock_exit_handler, mock_process, mock_local_execution): +def test_vscode_local_execution( + mock_download_vscode, mock_exit_handler, mock_process, mock_prepare_interactive_python, mock_local_execution +): @task @vscode def t(): @@ -67,6 +74,7 @@ def wf(): mock_download_vscode.assert_not_called() mock_process.assert_not_called() mock_exit_handler.assert_not_called() + mock_prepare_interactive_python.assert_not_called() def test_vscode_run_task_first_succeed(mock_remote_execution): @@ -85,9 +93,12 @@ def wf(a: int, b: int) -> int: @mock.patch("multiprocessing.Process") +@mock.patch("flytekitplugins.flyin.vscode_lib.decorator.prepare_interactive_python") @mock.patch("flytekitplugins.flyin.vscode_lib.decorator.exit_handler") @mock.patch("flytekitplugins.flyin.vscode_lib.decorator.download_vscode") -def test_vscode_run_task_first_fail(mock_download_vscode, mock_exit_handler, mock_process, mock_remote_execution): +def test_vscode_run_task_first_fail( + mock_download_vscode, mock_exit_handler, mock_process, mock_prepare_interactive_python, mock_remote_execution +): @task @vscode def t(a: int, b: int): @@ -102,6 +113,7 @@ def wf(a: int, b: int): mock_download_vscode.assert_called_once() mock_process.assert_called_once() mock_exit_handler.assert_called_once() + mock_prepare_interactive_python.assert_called_once() @mock.patch("flytekitplugins.flyin.jupyter_lib.decorator.subprocess.Popen") diff --git a/plugins/flytekit-flyin/tests/test_utils.py b/plugins/flytekit-flyin/tests/test_utils.py new file mode 100644 index 0000000000..99e73deba5 --- /dev/null +++ b/plugins/flytekit-flyin/tests/test_utils.py @@ -0,0 +1,20 @@ +import os + +from flytekitplugins.flyin import get_task_inputs +from flytekitplugins.flyin.utils import load_module_from_path + + +def test_load_module_from_path(): + module_name = "task" + module_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testdata", "task.py") + task_name = "t1" + task_module = load_module_from_path(module_name, module_path) + assert hasattr(task_module, task_name) + task_def = 
getattr(task_module, task_name) + assert task_def(a=6, b=3) == 2 + + +def test_get_task_inputs(): + test_working_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testdata") + native_inputs = get_task_inputs("task", "t1", test_working_dir) + assert native_inputs == {"a": 30, "b": 0} diff --git a/plugins/flytekit-flyin/tests/testdata/inputs.pb b/plugins/flytekit-flyin/tests/testdata/inputs.pb new file mode 100644 index 0000000000..3cb3128503 Binary files /dev/null and b/plugins/flytekit-flyin/tests/testdata/inputs.pb differ diff --git a/plugins/flytekit-flyin/tests/testdata/task.py b/plugins/flytekit-flyin/tests/testdata/task.py new file mode 100644 index 0000000000..ba035fcc84 --- /dev/null +++ b/plugins/flytekit-flyin/tests/testdata/task.py @@ -0,0 +1,8 @@ +from flytekitplugins.flyin import vscode +from flytekit import task + + +@task() +@vscode(run_task_first=True) +def t1(a: int, b: int) -> int: + return a // b diff --git a/plugins/flytekit-spark/dev-requirements.txt b/plugins/flytekit-spark/dev-requirements.txt index c6617d6228..3335091569 100644 --- a/plugins/flytekit-spark/dev-requirements.txt +++ b/plugins/flytekit-spark/dev-requirements.txt @@ -4,9 +4,9 @@ # # pip-compile dev-requirements.in # -aiohttp==3.8.6 +aiohttp==3.9.1 # via aioresponses -aioresponses==0.7.4 +aioresponses==0.7.6 # via -r dev-requirements.in aiosignal==1.3.1 # via aiohttp @@ -14,15 +14,13 @@ async-timeout==4.0.3 # via aiohttp attrs==23.1.0 # via aiohttp -charset-normalizer==3.2.0 - # via aiohttp -exceptiongroup==1.1.3 +exceptiongroup==1.2.0 # via pytest frozenlist==1.4.0 # via # aiohttp # aiosignal -idna==3.4 +idna==3.6 # via yarl iniconfig==2.0.0 # via pytest @@ -30,15 +28,15 @@ multidict==6.0.4 # via # aiohttp # yarl -packaging==23.1 +packaging==23.2 # via pytest pluggy==1.3.0 # via pytest -pytest==7.4.0 +pytest==7.4.3 # via pytest-asyncio pytest-asyncio==0.21.1 # via -r dev-requirements.in tomli==2.0.1 # via pytest -yarl==1.9.2 +yarl==1.9.3 # via aiohttp diff --git a/setup.py b/setup.py index fc8e971670..c3624921aa 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ "pandas>=1.0.0,<2.0.0", # TODO: Remove upper-bound after protobuf community fixes it. 
https://github.com/flyteorg/flyte/issues/4359 "protobuf<4.25.0", - "pyarrow>=4.0.0,<11.0.0", + "pyarrow>=4.0.0", "python-json-logger>=2.0.0", "pytimeparse>=1.1.8,<2.0.0", "pyyaml!=6.0.0,!=5.4.0,!=5.4.1", # pyyaml is broken with cython 3: https://github.com/yaml/pyyaml/issues/601 @@ -68,7 +68,6 @@ "statsd>=3.0.0,<4.0.0", "typing_extensions", "urllib3>=1.22,<2.0.0", - "wheel>=0.30.0,<1.0.0", ], extras_require=extras_require, scripts=[ diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt index 08cc149161..3d4c3fd037 100644 --- a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt +++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt @@ -8,7 +8,7 @@ adlfs==2023.4.0 # via flytekit aiobotocore==2.5.2 # via s3fs -aiohttp==3.8.6 +aiohttp==3.9.1 # via # adlfs # aiobotocore @@ -20,7 +20,7 @@ aiosignal==1.3.1 # via aiohttp arrow==1.2.3 # via cookiecutter -async-timeout==4.0.2 +async-timeout==4.0.3 # via aiohttp attrs==23.1.0 # via aiohttp @@ -68,7 +68,7 @@ cookiecutter==2.2.3 # via flytekit croniter==1.4.1 # via flytekit -cryptography==41.0.4 +cryptography==41.0.6 # via # azure-identity # azure-storage-blob @@ -348,7 +348,7 @@ wrapt==1.15.0 # via # aiobotocore # deprecated -yarl==1.9.2 +yarl==1.9.3 # via aiohttp zipp==3.16.2 # via importlib-metadata diff --git a/tests/flytekit/unit/core/test_checkpoint.py b/tests/flytekit/unit/core/test_checkpoint.py index b5fa46fe54..f05d37bd10 100644 --- a/tests/flytekit/unit/core/test_checkpoint.py +++ b/tests/flytekit/unit/core/test_checkpoint.py @@ -6,6 +6,7 @@ import flytekit from flytekit.core.checkpointer import SyncCheckpoint from flytekit.core.local_cache import LocalTaskCache +from flytekit.exceptions.user import FlyteAssertion def test_sync_checkpoint_write(tmpdir): @@ -72,6 +73,31 @@ def test_sync_checkpoint_restore(tmpdir): assert cp.restore("other_path") == user_dest +def test_sync_checkpoint_restore_corrupt(tmpdir): + td_path = Path(tmpdir) + dest = td_path.joinpath("dest") + dest.mkdir() + src = td_path.joinpath("src") + src.mkdir() + prev = src.joinpath("prev") + p = b"prev-bytes" + with prev.open("wb") as f: + f.write(p) + cp = SyncCheckpoint(checkpoint_dest=str(dest), checkpoint_src=str(src)) + user_dest = td_path.joinpath("user_dest") + user_dest.mkdir() + + # Simulate a failed upload of the checkpoint e.g. 
due to preemption + prev.unlink() + src.rmdir() + + with pytest.raises(FlyteAssertion): + cp.restore(user_dest) + + with pytest.raises(FlyteAssertion): + cp.restore(user_dest) + + def test_sync_checkpoint_restore_default_path(tmpdir): td_path = Path(tmpdir) dest = td_path.joinpath("dest") diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index d550c50ac4..227be2d0ff 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -6,7 +6,7 @@ import typing from dataclasses import asdict, dataclass, field from datetime import timedelta -from enum import Enum +from enum import Enum, auto from typing import Optional, Type import mock @@ -33,6 +33,7 @@ from flytekit.core.task import task from flytekit.core.type_engine import ( DataclassTransformer, + EnumTransformer, DictTransformer, ListTransformer, LiteralsResolver, @@ -1246,6 +1247,12 @@ class Color(Enum): BLUE = "blue" +class MultiInheritanceColor(str, Enum): + RED = auto() + GREEN = auto() + BLUE = auto() + + # Enums with integer values are not supported class UnsupportedEnumValues(Enum): RED = 1 @@ -1331,6 +1338,11 @@ def test_enum_type(): TypeEngine.to_literal_type(UnsupportedEnumValues) +def test_multi_inheritance_enum_type(): + tfm = TypeEngine.get_transformer(MultiInheritanceColor) + assert isinstance(tfm, EnumTransformer) + + def union_type_tags_unique(t: LiteralType): seen = set() for x in t.union_type.variants: