diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml
index 5fd44b1c0e..67af8f9575 100644
--- a/.github/workflows/pythonbuild.yml
+++ b/.github/workflows/pythonbuild.yml
@@ -356,6 +356,7 @@ jobs:
           - flytekit-papermill
           - flytekit-polars
           - flytekit-pydantic
+          - flytekit-pydantic-v2
           - flytekit-ray
           - flytekit-snowflake
           - flytekit-spark
diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 70344b6a86..0de3214f77 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -232,6 +232,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
         )
 
     def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]:
+        """
+        This is for dict, dataclass, and dataclass attribute access.
+        """
         if binary_idl_object.tag == MESSAGEPACK:
             try:
                 decoder = self._msgpack_decoder[expected_python_type]
@@ -242,6 +245,12 @@ def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[
         else:
             raise TypeTransformerFailedError(f"Unsupported binary format `{binary_idl_object.tag}`")
 
+    def from_generic_idl(self, generic: Struct, expected_python_type: Type[T]) -> Optional[T]:
+        """
+        This is for dataclass attribute access on inputs created from the Flyte Console.
+        """
+        raise NotImplementedError(f"Conversion from generic IDL to python type {expected_python_type} not implemented")
+
     def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str:
         """
         Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div
@@ -2256,6 +2265,17 @@ def _check_and_covert_float(lv: Literal) -> float:
     raise TypeTransformerFailedError(f"Cannot convert literal {lv} to float")
 
 
+def _check_and_convert_int(lv: Literal) -> int:
+    if lv.scalar.primitive.integer is not None:
+        return lv.scalar.primitive.integer
+
+    if lv.scalar.primitive.float_value is not None:
+        logger.info(f"Converting literal float {lv.scalar.primitive.float_value} to int; precision may be lost.")
+        return int(lv.scalar.primitive.float_value)
+
+    raise TypeTransformerFailedError(f"Cannot convert literal {lv} to int")
+
+
 def _check_and_convert_void(lv: Literal) -> None:
     if lv.scalar.none_type is None:
         raise TypeTransformerFailedError(f"Cannot convert literal {lv} to None")
@@ -2269,7 +2289,7 @@ def _register_default_type_transformers():
             int,
             _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER),
             lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x))),
-            lambda x: x.scalar.primitive.integer,
+            _check_and_convert_int,
         )
     )
 
diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py
index 04a1848f84..0ac5935da0 100644
--- a/flytekit/interaction/click_types.py
+++ b/flytekit/interaction/click_types.py
@@ -36,7 +36,10 @@ def is_pydantic_basemodel(python_type: typing.Type) -> bool:
         return False
     else:
         try:
-            from pydantic.v1 import BaseModel
+            from pydantic import BaseModel as BaseModelV2
+            from pydantic.v1 import BaseModel as BaseModelV1
+
+            return issubclass(python_type, BaseModelV1) or issubclass(python_type, BaseModelV2)
         except ImportError:
             from pydantic import BaseModel
 
diff --git a/flytekit/types/directory/__init__.py b/flytekit/types/directory/__init__.py
index 87b494d0ae..1bf9f30196 100644
--- a/flytekit/types/directory/__init__.py
+++ b/flytekit/types/directory/__init__.py
@@ -10,13 +10,14 @@
    :template: file_types.rst
 
    FlyteDirectory
+   FlyteDirToMultipartBlobTransformer
    TensorboardLogs
    TFRecordsDirectory
 """
 
 import typing
 
-from .types import FlyteDirectory
+from .types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
 
 # The following section provides some predefined aliases for commonly used FlyteDirectory formats.
 
diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py
index 518525914d..86dc33b2ae 100644
--- a/flytekit/types/directory/types.py
+++ b/flytekit/types/directory/types.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import json
 import os
 import pathlib
 import random
@@ -13,6 +14,8 @@
 import msgpack
 from dataclasses_json import DataClassJsonMixin, config
 from fsspec.utils import get_protocol
+from google.protobuf import json_format as _json_format
+from google.protobuf.struct_pb2 import Struct
 from marshmallow import fields
 from mashumaro.types import SerializableType
 
@@ -535,13 +538,45 @@ def from_binary_idl(
         else:
             raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
 
+    def from_generic_idl(self, generic: Struct, expected_python_type: typing.Type[FlyteDirectory]) -> FlyteDirectory:
+        json_str = _json_format.MessageToJson(generic)
+        python_val = json.loads(json_str)
+        path = python_val.get("path", None)
+
+        if path is None:
+            raise ValueError("FlyteDirectory's path should not be None")
+
+        return FlyteDirToMultipartBlobTransformer().to_python_value(
+            FlyteContextManager.current_context(),
+            Literal(
+                scalar=Scalar(
+                    blob=Blob(
+                        metadata=BlobMetadata(
+                            type=_core_types.BlobType(
+                                format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
+                            )
+                        ),
+                        uri=path,
+                    )
+                )
+            ),
+            expected_python_type,
+        )
+
     def to_python_value(
         self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[FlyteDirectory]
     ) -> FlyteDirectory:
-        if lv.scalar.binary:
-            return self.from_binary_idl(lv.scalar.binary, expected_python_type)
-
-        uri = lv.scalar.blob.uri
+        # Handle dataclass attribute access
+        if lv.scalar:
+            if lv.scalar.binary:
+                return self.from_binary_idl(lv.scalar.binary, expected_python_type)
+            if lv.scalar.generic:
+                return self.from_generic_idl(lv.scalar.generic, expected_python_type)
+
+        try:
+            uri = lv.scalar.blob.uri
+        except AttributeError:
+            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")
 
         if lv.scalar.blob.metadata.type.dimensionality != BlobType.BlobDimensionality.MULTIPART:
             raise TypeTransformerFailedError(f"{lv.scalar.blob.uri} is not a directory.")
diff --git a/flytekit/types/file/__init__.py b/flytekit/types/file/__init__.py
index 838516f33d..bf0e42fbb8 100644
--- a/flytekit/types/file/__init__.py
+++ b/flytekit/types/file/__init__.py
@@ -10,6 +10,7 @@
    :template: file_types.rst
 
    FlyteFile
+   FlyteFilePathTransformer
    HDF5EncodedFile
    HTMLPage
    JoblibSerializedFile
@@ -25,7 +26,7 @@
 
 from typing_extensions import Annotated, get_args, get_origin
 
-from .file import FlyteFile
+from .file import FlyteFile, FlyteFilePathTransformer
 
 
 class FileExt:
diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py
index 602f5bc12e..54cd4c53e8 100644
--- a/flytekit/types/file/file.py
+++ b/flytekit/types/file/file.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import json
 import mimetypes
 import os
 import pathlib
@@ -11,6 +12,8 @@
 import msgpack
 from dataclasses_json import config
+from google.protobuf import json_format as _json_format
+from google.protobuf.struct_pb2 import Struct
 from marshmallow import fields
 from mashumaro.mixins.json import DataClassJSONMixin
 from mashumaro.types import SerializableType
 
@@ -554,12 +557,42 @@ def from_binary_idl(
         else:
             raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
 
+    def from_generic_idl(
+        self, generic: Struct, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike]
+    ) -> FlyteFile:
+        json_str = _json_format.MessageToJson(generic)
+        python_val = json.loads(json_str)
+        path = python_val.get("path", None)
+
+        if path is None:
+            raise ValueError("FlyteFile's path should not be None")
+
+        return FlyteFilePathTransformer().to_python_value(
+            FlyteContextManager.current_context(),
+            Literal(
+                scalar=Scalar(
+                    blob=Blob(
+                        metadata=BlobMetadata(
+                            type=_core_types.BlobType(
+                                format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
+                            )
+                        ),
+                        uri=path,
+                    )
+                )
+            ),
+            expected_python_type,
+        )
+
     async def async_to_python_value(
         self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike]
     ) -> FlyteFile:
         # Handle dataclass attribute access
-        if lv.scalar and lv.scalar.binary:
-            return self.from_binary_idl(lv.scalar.binary, expected_python_type)
+        if lv.scalar:
+            if lv.scalar.binary:
+                return self.from_binary_idl(lv.scalar.binary, expected_python_type)
+            if lv.scalar.generic:
+                return self.from_generic_idl(lv.scalar.generic, expected_python_type)
 
         try:
             uri = lv.scalar.blob.uri
diff --git a/flytekit/types/schema/__init__.py b/flytekit/types/schema/__init__.py
index 080927021a..33ee8ef72c 100644
--- a/flytekit/types/schema/__init__.py
+++ b/flytekit/types/schema/__init__.py
@@ -1,5 +1,6 @@
 from .types import (
     FlyteSchema,
+    FlyteSchemaTransformer,
     LocalIOSchemaReader,
     LocalIOSchemaWriter,
     SchemaEngine,
diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py
index 5cf8308b03..45d7fd28a5 100644
--- a/flytekit/types/schema/types.py
+++ b/flytekit/types/schema/types.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import datetime
+import json
 import os
 import typing
 from abc import abstractmethod
@@ -11,6 +12,8 @@
 import msgpack
 from dataclasses_json import config
+from google.protobuf import json_format as _json_format
+from google.protobuf.struct_pb2 import Struct
 from marshmallow import fields
 from mashumaro.mixins.json import DataClassJSONMixin
 from mashumaro.types import SerializableType
 
@@ -458,10 +461,28 @@ def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[
         else:
             raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
 
+    def from_generic_idl(self, generic: Struct, expected_python_type: Type[FlyteSchema]) -> FlyteSchema:
+        json_str = _json_format.MessageToJson(generic)
+        python_val = json.loads(json_str)
+
+        remote_path = python_val.get("remote_path", None)
+        if remote_path is None:
+            raise ValueError("FlyteSchema's remote_path should not be None")
+
+        t = FlyteSchemaTransformer()
+        return t.to_python_value(
+            FlyteContextManager.current_context(),
+            Literal(scalar=Scalar(schema=Schema(remote_path, t._get_schema_type(expected_python_type)))),
+            expected_python_type,
+        )
+
     def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[FlyteSchema]) -> FlyteSchema:
         # Handle dataclass attribute access
-        if lv.scalar and lv.scalar.binary:
-            return self.from_binary_idl(lv.scalar.binary, expected_python_type)
+        if lv.scalar:
+            if lv.scalar.binary:
+                return self.from_binary_idl(lv.scalar.binary, expected_python_type)
+            if lv.scalar.generic:
+                return self.from_generic_idl(lv.scalar.generic, expected_python_type)
 
         def downloader(x, y):
             ctx.file_access.get_data(x, y, is_multipart=True)
diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py
index 05d1fa86e3..1105cd23d5 100644
--- a/flytekit/types/structured/__init__.py
+++ b/flytekit/types/structured/__init__.py
@@ -7,9 +7,12 @@
    :template: custom.rst
    :toctree: generated/
 
-    StructuredDataset
-    StructuredDatasetEncoder
-    StructuredDatasetDecoder
+   StructuredDataset
+   StructuredDatasetDecoder
+   StructuredDatasetEncoder
+   StructuredDatasetMetadata
+   StructuredDatasetTransformerEngine
+   StructuredDatasetType
 """
 
 from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer
@@ -19,7 +22,9 @@
     StructuredDataset,
     StructuredDatasetDecoder,
     StructuredDatasetEncoder,
+    StructuredDatasetMetadata,
     StructuredDatasetTransformerEngine,
+    StructuredDatasetType,
 )
 
diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py
index 75b20fe08c..e8f7929d60 100644
--- a/flytekit/types/structured/structured_dataset.py
+++ b/flytekit/types/structured/structured_dataset.py
@@ -2,6 +2,7 @@
 
 import _datetime
 import collections
+import json
 import types
 import typing
 from abc import ABC, abstractmethod
@@ -11,6 +12,8 @@
 import msgpack
 from dataclasses_json import config
 from fsspec.utils import get_protocol
+from google.protobuf import json_format as _json_format
+from google.protobuf.struct_pb2 import Struct
 from marshmallow import fields
 from mashumaro.mixins.json import DataClassJSONMixin
 from mashumaro.types import SerializableType
 
@@ -745,6 +748,33 @@ def from_binary_idl(
         else:
             raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
 
+    def from_generic_idl(
+        self, generic: Struct, expected_python_type: Type[T] | StructuredDataset
+    ) -> T | StructuredDataset:
+        json_str = _json_format.MessageToJson(generic)
+        python_val = json.loads(json_str)
+
+        uri = python_val.get("uri", None)
+        file_format = python_val.get("file_format", None)
+
+        if uri is None:
+            raise ValueError("StructuredDataset's uri should not be None")
+
+        return StructuredDatasetTransformerEngine().to_python_value(
+            FlyteContextManager.current_context(),
+            Literal(
+                scalar=Scalar(
+                    structured_dataset=StructuredDataset(
+                        metadata=StructuredDatasetMetadata(
+                            structured_dataset_type=StructuredDatasetType(format=file_format)
+                        ),
+                        uri=uri,
+                    )
+                )
+            ),
+            expected_python_type,
+        )
+
     def to_python_value(
         self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] | StructuredDataset
     ) -> T | StructuredDataset:
@@ -779,8 +809,11 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...
         +-----------------------------+-----------------------------------------+--------------------------------------+
         """
         # Handle dataclass attribute access
-        if lv.scalar and lv.scalar.binary:
-            return self.from_binary_idl(lv.scalar.binary, expected_python_type)
+        if lv.scalar:
+            if lv.scalar.binary:
+                return self.from_binary_idl(lv.scalar.binary, expected_python_type)
+            if lv.scalar.generic:
+                return self.from_generic_idl(lv.scalar.generic, expected_python_type)
 
         # Detect annotations and extract out all the relevant information that the user might supply
         expected_python_type, column_dict, storage_fmt, pa_schema = extract_cols_and_format(expected_python_type)
diff --git a/plugins/flytekit-pydantic-v2/README.md b/plugins/flytekit-pydantic-v2/README.md
new file mode 100644
index 0000000000..a8ec17a117
--- /dev/null
+++ b/plugins/flytekit-pydantic-v2/README.md
@@ -0,0 +1 @@
+# TMP
diff --git a/plugins/flytekit-pydantic-v2/flytekitplugins/pydantic/__init__.py b/plugins/flytekit-pydantic-v2/flytekitplugins/pydantic/__init__.py
new file mode 100644
index 0000000000..99f2aa0f23
--- /dev/null
+++ b/plugins/flytekit-pydantic-v2/flytekitplugins/pydantic/__init__.py
@@ -0,0 +1,25 @@
+from flytekit.types.directory import FlyteDirectory
+from flytekit.types.file import FlyteFile
+from flytekit.types.schema import FlyteSchema
+from flytekit.types.structured import StructuredDataset
+
+from . import transformer
+from .custom import (
+    deserialize_flyte_dir,
+    deserialize_flyte_file,
+    deserialize_flyte_schema,
+    deserialize_structured_dataset,
+    serialize_flyte_dir,
+    serialize_flyte_file,
+    serialize_flyte_schema,
+    serialize_structured_dataset,
+)
+
+setattr(FlyteFile, "serialize_flyte_file", serialize_flyte_file)
+setattr(FlyteFile, "deserialize_flyte_file", deserialize_flyte_file)
+setattr(FlyteDirectory, "serialize_flyte_dir", serialize_flyte_dir)
+setattr(FlyteDirectory, "deserialize_flyte_dir", deserialize_flyte_dir)
+setattr(FlyteSchema, "serialize_flyte_schema", serialize_flyte_schema)
+setattr(FlyteSchema, "deserialize_flyte_schema", deserialize_flyte_schema)
+setattr(StructuredDataset, "serialize_structured_dataset", serialize_structured_dataset)
+setattr(StructuredDataset, "deserialize_structured_dataset", deserialize_structured_dataset)
diff --git a/plugins/flytekit-pydantic-v2/flytekitplugins/pydantic/custom.py b/plugins/flytekit-pydantic-v2/flytekitplugins/pydantic/custom.py
new file mode 100644
index 0000000000..7e536ae5d4
--- /dev/null
+++ b/plugins/flytekit-pydantic-v2/flytekitplugins/pydantic/custom.py
@@ -0,0 +1,130 @@
+from typing import Dict
+
+from flytekit.core.context_manager import FlyteContextManager
+from flytekit.models.core import types as _core_types
+from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar, Schema
+from flytekit.types.directory import FlyteDirectory, FlyteDirToMultipartBlobTransformer
+from flytekit.types.file import FlyteFile, FlyteFilePathTransformer
+from flytekit.types.schema import FlyteSchema, FlyteSchemaTransformer
+from flytekit.types.structured import (
+    StructuredDataset,
+    StructuredDatasetMetadata,
+    StructuredDatasetTransformerEngine,
+    StructuredDatasetType,
+)
+from pydantic import model_serializer, model_validator
+
+
+@model_serializer
+def serialize_flyte_file(self) -> Dict[str, str]:
+    lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
+    return {"path": lv.scalar.blob.uri}
+
+
+@model_validator(mode="after")
+def deserialize_flyte_file(self) -> FlyteFile:
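+    # Rebuild a single-part blob literal from the serialized "path" and run it back
+    # through FlyteFilePathTransformer, so the validated model carries a usable FlyteFile.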
+    pv = FlyteFilePathTransformer().to_python_value(
+        FlyteContextManager.current_context(),
+        Literal(
+            scalar=Scalar(
+                blob=Blob(
+                    metadata=BlobMetadata(
+                        type=_core_types.BlobType(
+                            format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
+                        )
+                    ),
+                    uri=self.path,
+                )
+            )
+        ),
+        type(self),
+    )
+    pv._remote_path = None
+    return pv
+
+
+@model_serializer
+def serialize_flyte_dir(self) -> Dict[str, str]:
+    lv = FlyteDirToMultipartBlobTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
+    return {"path": lv.scalar.blob.uri}
+
+
+@model_validator(mode="after")
+def deserialize_flyte_dir(self) -> FlyteDirectory:
+    pv = FlyteDirToMultipartBlobTransformer().to_python_value(
+        FlyteContextManager.current_context(),
+        Literal(
+            scalar=Scalar(
+                blob=Blob(
+                    metadata=BlobMetadata(
+                        type=_core_types.BlobType(
+                            format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
+                        )
+                    ),
+                    uri=self.path,
+                )
+            )
+        ),
+        type(self),
+    )
+    pv._remote_directory = None
+    return pv
+
+
+@model_serializer
+def serialize_flyte_schema(self) -> Dict[str, str]:
+    FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
+    return {"remote_path": self.remote_path}
+
+
+@model_validator(mode="after")
+def deserialize_flyte_schema(self) -> FlyteSchema:
+    # If we call the method to_python_value, FlyteSchemaTransformer will overwrite the local_path,
+    # which will lose our data.
+    # If this data is from an existing FlyteSchema, the local path will be None.
+
+    if hasattr(self, "_local_path"):
+        return self
+
+    t = FlyteSchemaTransformer()
+    return t.to_python_value(
+        FlyteContextManager.current_context(),
+        Literal(scalar=Scalar(schema=Schema(self.remote_path, t._get_schema_type(type(self))))),
+        type(self),
+    )
+
+
+@model_serializer
+def serialize_structured_dataset(self) -> Dict[str, str]:
+    lv = StructuredDatasetTransformerEngine().to_literal(FlyteContextManager.current_context(), self, type(self), None)
+    sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri)
+    sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
+    return {
+        "uri": sd.uri,
+        "file_format": sd.file_format,
+    }
+
+
+@model_validator(mode="after")
+def deserialize_structured_dataset(self) -> StructuredDataset:
+    # If we call the method to_python_value, StructuredDatasetTransformerEngine will overwrite the 'dataframe',
+    # which will lose our data.
+    # If this data is from an existing StructuredDataset, the dataframe will be None.
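+    # Only rebuild from uri/file_format when no dataframe is attached.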
+
+    if hasattr(self, "dataframe"):
+        return self
+
+    return StructuredDatasetTransformerEngine().to_python_value(
+        FlyteContextManager.current_context(),
+        Literal(
+            scalar=Scalar(
+                structured_dataset=StructuredDataset(
+                    metadata=StructuredDatasetMetadata(
+                        structured_dataset_type=StructuredDatasetType(format=self.file_format)
+                    ),
+                    uri=self.uri,
+                )
+            )
+        ),
+        type(self),
+    )
diff --git a/plugins/flytekit-pydantic-v2/flytekitplugins/pydantic/transformer.py b/plugins/flytekit-pydantic-v2/flytekitplugins/pydantic/transformer.py
new file mode 100644
index 0000000000..2749ddd829
--- /dev/null
+++ b/plugins/flytekit-pydantic-v2/flytekitplugins/pydantic/transformer.py
@@ -0,0 +1,75 @@
+import json
+from typing import Type
+
+import msgpack
+from google.protobuf import json_format as _json_format
+
+from flytekit import FlyteContext
+from flytekit.core.constants import MESSAGEPACK
+from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
+from flytekit.loggers import logger
+from flytekit.models import types
+from flytekit.models.literals import Binary, Literal, Scalar
+from flytekit.models.types import LiteralType, TypeStructure
+from pydantic import BaseModel
+
+
+class PydanticTransformer(TypeTransformer[BaseModel]):
+    def __init__(self):
+        super().__init__("Pydantic Transformer", BaseModel, enable_type_assertions=False)
+
+    def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
+        schema = t.model_json_schema()
+        literal_type = {}
+        fields = t.__annotations__.items()
+
+        for name, python_type in fields:
+            try:
+                literal_type[name] = TypeEngine.to_literal_type(python_type)
+            except Exception as e:
+                logger.warning(
+                    "Field {} of type {} cannot be converted to a literal type. Error: {}".format(name, python_type, e)
+                )
+
+        ts = TypeStructure(tag="", dataclass_type=literal_type)
+
+        return types.LiteralType(simple=types.SimpleType.STRUCT, metadata=schema, structure=ts)
+
+    def to_literal(
+        self,
+        ctx: FlyteContext,
+        python_val: BaseModel,
+        python_type: Type[BaseModel],
+        expected: types.LiteralType,
+    ) -> Literal:
+        json_str = python_val.model_dump_json()
+        dict_obj = json.loads(json_str)
+        msgpack_bytes = msgpack.dumps(dict_obj)
+        return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK)))
+
+    def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
+        if binary_idl_object.tag == MESSAGEPACK:
+            dict_obj = msgpack.loads(binary_idl_object.value, raw=False, strict_map_key=False)
+            json_str = json.dumps(dict_obj)
+            python_val = expected_python_type.model_validate_json(json_data=json_str, strict=False)
+            return python_val
+        else:
+            raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
+
+    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[BaseModel]) -> BaseModel:
+        """
+        There are two kinds of literal values:
+        1. protobuf Struct (from the Flyte Console)
+        2. binary scalar (everything else)
+        Hence we have to handle both cases.
+        """
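+        # Binary scalars carry flytekit's own msgpack payload; anything else is assumed
+        # to be a protobuf Struct produced by the Flyte Console.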
+ """ + if lv and lv.scalar and lv.scalar.binary is not None: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore + + json_str = _json_format.MessageToJson(lv.scalar.generic) + dict_obj = json.loads(json_str) + python_val = expected_python_type.model_validate(obj=dict_obj, strict=False) + return python_val + + +TypeEngine.register(PydanticTransformer()) diff --git a/plugins/flytekit-pydantic-v2/setup.py b/plugins/flytekit-pydantic-v2/setup.py new file mode 100644 index 0000000000..9063ca45c0 --- /dev/null +++ b/plugins/flytekit-pydantic-v2/setup.py @@ -0,0 +1,40 @@ +from setuptools import setup + +PLUGIN_NAME = "pydantic" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}-v2" + +plugin_requires = ["flytekit>1.13.7", "pydantic>=2.6.0"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="Plugin adding type support for Pydantic models", + url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekitplugins-pydantic-v2", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.9", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-pydantic-v2/tests/test_pydantic_type_transformer.py b/plugins/flytekit-pydantic-v2/tests/test_pydantic_type_transformer.py new file mode 100644 index 0000000000..c8630bfb5d --- /dev/null +++ b/plugins/flytekit-pydantic-v2/tests/test_pydantic_type_transformer.py @@ -0,0 +1,477 @@ +import os +import tempfile +from dataclasses import field +from enum import Enum +from typing import Dict, List +from pydantic import BaseModel, Field + +import pytest +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct + +from flytekit import task, workflow +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.type_engine import TypeEngine +from flytekit.models.literals import Literal, Scalar +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile + +class Status(Enum): + PENDING = "pending" + APPROVED = "approved" + REJECTED = "rejected" + + +@pytest.fixture +def local_dummy_file(): + fd, path = tempfile.mkstemp() + try: + with os.fdopen(fd, "w") as tmp: + tmp.write("Hello FlyteFile") + yield path + finally: + os.remove(path) + + +@pytest.fixture +def local_dummy_directory(): + temp_dir = tempfile.TemporaryDirectory() + try: + with open(os.path.join(temp_dir.name, "file"), "w") as tmp: + tmp.write("Hello FlyteDirectory") + yield temp_dir.name + finally: + temp_dir.cleanup() + + +def test_flytetypes_in_pydantic_basemodel_wf(local_dummy_file, local_dummy_directory): + class 
+    class InnerDC(BaseModel):
+        flytefile: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file))
+        flytedir: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory))
+
+    class DC(BaseModel):
+        flytefile: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file))
+        flytedir: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory))
+        inner_dc: InnerDC = field(default_factory=lambda: InnerDC())
+
+    @task
+    def t1(path: FlyteFile) -> FlyteFile:
+        return path
+
+    @task
+    def t2(path: FlyteDirectory) -> FlyteDirectory:
+        return path
+
+    @workflow
+    def wf(dc: DC) -> (FlyteFile, FlyteFile, FlyteDirectory, FlyteDirectory):
+        f1 = t1(path=dc.flytefile)
+        f2 = t1(path=dc.inner_dc.flytefile)
+        d1 = t2(path=dc.flytedir)
+        d2 = t2(path=dc.inner_dc.flytedir)
+        return f1, f2, d1, d2
+
+    o1, o2, o3, o4 = wf(dc=DC())
+    with open(o1, "r") as fh:
+        assert fh.read() == "Hello FlyteFile"
+
+    with open(o2, "r") as fh:
+        assert fh.read() == "Hello FlyteFile"
+
+    with open(os.path.join(o3, "file"), "r") as fh:
+        assert fh.read() == "Hello FlyteDirectory"
+
+    with open(os.path.join(o4, "file"), "r") as fh:
+        assert fh.read() == "Hello FlyteDirectory"
+
+
+def test_all_types_in_pydantic_basemodel_wf(local_dummy_file, local_dummy_directory):
+    class InnerDC(BaseModel):
+        a: int = -1
+        b: float = 2.1
+        c: str = "Hello, Flyte"
+        d: bool = False
+        e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2])
+        f: List[FlyteFile] = field(default_factory=lambda: [FlyteFile(local_dummy_file)])
+        g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]])
+        h: List[Dict[int, bool]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}])
+        i: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False})
+        j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file),
+                                                                 1: FlyteFile(local_dummy_file),
+                                                                 -1: FlyteFile(local_dummy_file)})
+        k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]})
+        l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}})
+        m: dict = field(default_factory=lambda: {"key": "value"})
+        n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file))
+        o: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory))
+        enum_status: Status = field(default=Status.PENDING)
+
+    class DC(BaseModel):
+        a: int = -1
+        b: float = 2.1
+        c: str = "Hello, Flyte"
+        d: bool = False
+        e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2])
+        f: List[FlyteFile] = field(default_factory=lambda: [FlyteFile(local_dummy_file), ])
+        g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]])
+        h: List[Dict[int, bool]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}])
+        i: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False})
+        j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file),
+                                                                 1: FlyteFile(local_dummy_file),
+                                                                 -1: FlyteFile(local_dummy_file)})
+        k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]})
+        l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}})
+        m: dict = field(default_factory=lambda: {"key": "value"})
+        n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file))
+        o: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory))
+        inner_dc: InnerDC = field(default_factory=lambda: InnerDC())
+        enum_status: Status = field(default=Status.PENDING)
+
+    @task
+    def t_inner(inner_dc: InnerDC):
+        assert type(inner_dc) is InnerDC
+
+        # f: List[FlyteFile]
+        for ff in inner_dc.f:
+            assert type(ff) is FlyteFile
+            with open(ff, "r") as f:
+                assert f.read() == "Hello FlyteFile"
+        # j: Dict[int, FlyteFile]
+        for _, ff in inner_dc.j.items():
+            assert type(ff) is FlyteFile
+            with open(ff, "r") as f:
+                assert f.read() == "Hello FlyteFile"
+        # n: FlyteFile
+        assert type(inner_dc.n) is FlyteFile
+        with open(inner_dc.n, "r") as f:
+            assert f.read() == "Hello FlyteFile"
+        # o: FlyteDirectory
+        assert type(inner_dc.o) is FlyteDirectory
+        assert not inner_dc.o.downloaded
+        with open(os.path.join(inner_dc.o, "file"), "r") as fh:
+            assert fh.read() == "Hello FlyteDirectory"
+        assert inner_dc.o.downloaded
+
+        # enum: Status
+        assert inner_dc.enum_status == Status.PENDING
+
+    @task
+    def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: List[FlyteFile], g: List[List[int]],
+                              h: List[Dict[int, bool]], i: Dict[int, bool], j: Dict[int, FlyteFile],
+                              k: Dict[int, List[int]], l: Dict[int, Dict[int, int]], m: dict,
+                              n: FlyteFile, o: FlyteDirectory, enum_status: Status):
+        # Strict type checks for simple types
+        assert isinstance(a, int), f"a is not int, it's {type(a)}"
+        assert a == -1
+        assert isinstance(b, float), f"b is not float, it's {type(b)}"
+        assert isinstance(c, str), f"c is not str, it's {type(c)}"
+        assert isinstance(d, bool), f"d is not bool, it's {type(d)}"
+
+        # Strict type checks for List[int]
+        assert isinstance(e, list) and all(isinstance(i, int) for i in e), "e is not List[int]"
+
+        # Strict type checks for List[FlyteFile]
+        assert isinstance(f, list) and all(isinstance(i, FlyteFile) for i in f), "f is not List[FlyteFile]"
+
+        # Strict type checks for List[List[int]]
+        assert isinstance(g, list) and all(
+            isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]"
+
+        # Strict type checks for List[Dict[int, bool]]
+        assert isinstance(h, list) and all(
+            isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h
+        ), "h is not List[Dict[int, bool]]"
+
+        # Strict type checks for Dict[int, bool]
+        assert isinstance(i, dict) and all(
+            isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]"
+
+        # Strict type checks for Dict[int, FlyteFile]
+        assert isinstance(j, dict) and all(
+            isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not Dict[int, FlyteFile]"
+
+        # Strict type checks for Dict[int, List[int]]
+        assert isinstance(k, dict) and all(
+            isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in
+            k.items()), "k is not Dict[int, List[int]]"
+
+        # Strict type checks for Dict[int, Dict[int, int]]
+        assert isinstance(l, dict) and all(
+            isinstance(k, int) and isinstance(v, dict) and all(
+                isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items())
+            for k, v in l.items()), "l is not Dict[int, Dict[int, int]]"
+
+        # Strict type check for a generic dict
+        assert isinstance(m, dict), "m is not dict"
+
+        # Strict type check for FlyteFile
+        assert isinstance(n, FlyteFile), "n is not FlyteFile"
+
+        # Strict type check for FlyteDirectory
+        assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory"
+
+        # Strict type check for Enum
+        assert isinstance(enum_status, Status), "enum_status is not Status"
+
+        print("All attributes passed strict type checks.")
+
+    @workflow
+    def wf(dc: DC):
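+        # Exercises promise attribute access (dc.a, dc.inner_dc.b, ...) on a BaseModel input.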
+        t_inner(dc.inner_dc)
+        t_test_all_attributes(a=dc.a, b=dc.b, c=dc.c,
+                              d=dc.d, e=dc.e, f=dc.f,
+                              g=dc.g, h=dc.h, i=dc.i,
+                              j=dc.j, k=dc.k, l=dc.l,
+                              m=dc.m, n=dc.n, o=dc.o, enum_status=dc.enum_status)
+
+        t_test_all_attributes(a=dc.inner_dc.a, b=dc.inner_dc.b, c=dc.inner_dc.c,
+                              d=dc.inner_dc.d, e=dc.inner_dc.e, f=dc.inner_dc.f,
+                              g=dc.inner_dc.g, h=dc.inner_dc.h, i=dc.inner_dc.i,
+                              j=dc.inner_dc.j, k=dc.inner_dc.k, l=dc.inner_dc.l,
+                              m=dc.inner_dc.m, n=dc.inner_dc.n, o=dc.inner_dc.o, enum_status=dc.inner_dc.enum_status)
+
+    wf(dc=DC())
+
+
+def test_input_from_flyte_console_pydantic_basemodel(local_dummy_file, local_dummy_directory):
+    # Flyte Console will send the input data as protobuf Struct
+
+    class InnerDC(BaseModel):
+        a: int = -1
+        b: float = 2.1
+        c: str = "Hello, Flyte"
+        d: bool = False
+        e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2])
+        f: List[FlyteFile] = field(default_factory=lambda: [FlyteFile(local_dummy_file)])
+        g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]])
+        h: List[Dict[int, bool]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}])
+        i: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False})
+        j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file),
+                                                                 1: FlyteFile(local_dummy_file),
+                                                                 -1: FlyteFile(local_dummy_file)})
+        k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]})
+        l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}})
+        m: dict = field(default_factory=lambda: {"key": "value"})
+        n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file))
+        o: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory))
+        enum_status: Status = field(default=Status.PENDING)
+
+    class DC(BaseModel):
+        a: int = -1
+        b: float = 2.1
+        c: str = "Hello, Flyte"
+        d: bool = False
+        e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2])
+        f: List[FlyteFile] = field(default_factory=lambda: [FlyteFile(local_dummy_file), ])
+        g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]])
+        h: List[Dict[int, bool]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}])
+        i: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False})
+        j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file),
+                                                                 1: FlyteFile(local_dummy_file),
+                                                                 -1: FlyteFile(local_dummy_file)})
+        k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]})
+        l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}})
+        m: dict = field(default_factory=lambda: {"key": "value"})
+        n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file))
+        o: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory))
+        inner_dc: InnerDC = field(default_factory=lambda: InnerDC())
+        enum_status: Status = field(default=Status.PENDING)
+
+    @task
+    def t_inner(inner_dc: InnerDC):
+        assert type(inner_dc) is InnerDC
+
+        # f: List[FlyteFile]
+        for ff in inner_dc.f:
+            assert type(ff) is FlyteFile
+            with open(ff, "r") as f:
+                assert f.read() == "Hello FlyteFile"
+        # j: Dict[int, FlyteFile]
+        for _, ff in inner_dc.j.items():
+            assert type(ff) is FlyteFile
+            with open(ff, "r") as f:
+                assert f.read() == "Hello FlyteFile"
+        # n: FlyteFile
+        assert type(inner_dc.n) is FlyteFile
+        with open(inner_dc.n, "r") as f:
+            assert f.read() == "Hello FlyteFile"
+        # o: FlyteDirectory
+        assert type(inner_dc.o) is FlyteDirectory
+        assert not inner_dc.o.downloaded
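+        # os.path.join() invokes __fspath__, which triggers the lazy download.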
+        with open(os.path.join(inner_dc.o, "file"), "r") as fh:
+            assert fh.read() == "Hello FlyteDirectory"
+        assert inner_dc.o.downloaded
+
+        # enum: Status
+        assert inner_dc.enum_status == Status.PENDING
+
+    def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: List[FlyteFile], g: List[List[int]],
+                              h: List[Dict[int, bool]], i: Dict[int, bool], j: Dict[int, FlyteFile],
+                              k: Dict[int, List[int]], l: Dict[int, Dict[int, int]], m: dict,
+                              n: FlyteFile, o: FlyteDirectory, enum_status: Status):
+        # Strict type checks for simple types
+        assert isinstance(a, int), f"a is not int, it's {type(a)}"
+        assert a == -1
+        assert isinstance(b, float), f"b is not float, it's {type(b)}"
+        assert isinstance(c, str), f"c is not str, it's {type(c)}"
+        assert isinstance(d, bool), f"d is not bool, it's {type(d)}"
+
+        # Strict type checks for List[int]
+        assert isinstance(e, list) and all(isinstance(i, int) for i in e), "e is not List[int]"
+
+        # Strict type checks for List[FlyteFile]
+        assert isinstance(f, list) and all(isinstance(i, FlyteFile) for i in f), "f is not List[FlyteFile]"
+
+        # Strict type checks for List[List[int]]
+        assert isinstance(g, list) and all(
+            isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]"
+
+        # Strict type checks for List[Dict[int, bool]]
+        assert isinstance(h, list) and all(
+            isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h
+        ), "h is not List[Dict[int, bool]]"
+
+        # Strict type checks for Dict[int, bool]
+        assert isinstance(i, dict) and all(
+            isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]"
+
+        # Strict type checks for Dict[int, FlyteFile]
+        assert isinstance(j, dict) and all(
+            isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not Dict[int, FlyteFile]"
+
+        # Strict type checks for Dict[int, List[int]]
+        assert isinstance(k, dict) and all(
+            isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in
+            k.items()), "k is not Dict[int, List[int]]"
+
+        # Strict type checks for Dict[int, Dict[int, int]]
+        assert isinstance(l, dict) and all(
+            isinstance(k, int) and isinstance(v, dict) and all(
+                isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items())
+            for k, v in l.items()), "l is not Dict[int, Dict[int, int]]"
+
+        # Strict type check for a generic dict
+        assert isinstance(m, dict), "m is not dict"
+
+        # Strict type check for FlyteFile
+        assert isinstance(n, FlyteFile), "n is not FlyteFile"
+
+        # Strict type check for FlyteDirectory
+        assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory"
+
+        # Strict type check for Enum
+        assert isinstance(enum_status, Status), "enum_status is not Status"
+
+        print("All attributes passed strict type checks.")
+
+    # This is the old dataclass serialization behavior.
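+    # Simulate the Console by parsing the model's JSON dump into a protobuf Struct.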
+    # https://github.com/flyteorg/flytekit/blob/94786cfd4a5c2c3b23ac29dcd6f04d0553fa1beb/flytekit/core/type_engine.py#L702-L728
+    dc = DC()
+    json_str = dc.model_dump_json()
+    upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())))
+
+    downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, DC)
+    t_inner(downstream_input.inner_dc)
+    t_test_all_attributes(a=downstream_input.a, b=downstream_input.b, c=downstream_input.c,
+                          d=downstream_input.d, e=downstream_input.e, f=downstream_input.f,
+                          g=downstream_input.g, h=downstream_input.h, i=downstream_input.i,
+                          j=downstream_input.j, k=downstream_input.k, l=downstream_input.l,
+                          m=downstream_input.m, n=downstream_input.n, o=downstream_input.o,
+                          enum_status=downstream_input.enum_status)
+    t_test_all_attributes(a=downstream_input.inner_dc.a, b=downstream_input.inner_dc.b, c=downstream_input.inner_dc.c,
+                          d=downstream_input.inner_dc.d, e=downstream_input.inner_dc.e, f=downstream_input.inner_dc.f,
+                          g=downstream_input.inner_dc.g, h=downstream_input.inner_dc.h, i=downstream_input.inner_dc.i,
+                          j=downstream_input.inner_dc.j, k=downstream_input.inner_dc.k, l=downstream_input.inner_dc.l,
+                          m=downstream_input.inner_dc.m, n=downstream_input.inner_dc.n, o=downstream_input.inner_dc.o,
+                          enum_status=downstream_input.inner_dc.enum_status)
+
+
+def test_dataclass_in_pydantic_basemodel():
+    from dataclasses import dataclass
+
+    @dataclass
+    class InnerDC:
+        a: int = -1
+        b: float = 3.14
+        c: str = "Hello, Flyte"
+        d: bool = False
+
+    class DC(BaseModel):
+        a: int = -1
+        b: float = 3.14
+        c: str = "Hello, Flyte"
+        d: bool = False
+        inner_dc: InnerDC = Field(default_factory=lambda: InnerDC())
+
+    @task
+    def t_dc(dc: DC):
+        assert isinstance(dc, DC)
+        assert isinstance(dc.inner_dc, InnerDC)
+
+    @task
+    def t_inner(inner_dc: InnerDC):
+        assert isinstance(inner_dc, InnerDC)
+
+    @task
+    def t_test_primitive_attributes(a: int, b: float, c: str, d: bool):
+        assert isinstance(a, int), f"a is not int, it's {type(a)}"
+        assert a == -1
+        assert isinstance(b, float), f"b is not float, it's {type(b)}"
+        assert b == 3.14
+        assert isinstance(c, str), f"c is not str, it's {type(c)}"
+        assert c == "Hello, Flyte"
+        assert isinstance(d, bool), f"d is not bool, it's {type(d)}"
+        assert d is False
+        print("All primitive attributes passed strict type checks.")
+
+    @workflow
+    def wf(dc: DC):
+        t_dc(dc=dc)
+        t_inner(inner_dc=dc.inner_dc)
+        t_test_primitive_attributes(a=dc.a, b=dc.b, c=dc.c, d=dc.d)
+        t_test_primitive_attributes(a=dc.inner_dc.a, b=dc.inner_dc.b, c=dc.inner_dc.c, d=dc.inner_dc.d)
+
+    dc = DC()
+    wf(dc=dc)
+
+
+def test_pydantic_dataclass_in_pydantic_basemodel():
+    from pydantic.dataclasses import dataclass
+
+    @dataclass
+    class InnerDC:
+        a: int = -1
+        b: float = 3.14
+        c: str = "Hello, Flyte"
+        d: bool = False
+
+    class DC(BaseModel):
+        a: int = -1
+        b: float = 3.14
+        c: str = "Hello, Flyte"
+        d: bool = False
+        inner_dc: InnerDC = Field(default_factory=lambda: InnerDC())
+
+    @task
+    def t_dc(dc: DC):
+        assert isinstance(dc, DC)
+        assert isinstance(dc.inner_dc, InnerDC)
+
+    @task
+    def t_inner(inner_dc: InnerDC):
+        assert isinstance(inner_dc, InnerDC)
+
+    @task
+    def t_test_primitive_attributes(a: int, b: float, c: str, d: bool):
+        assert isinstance(a, int), f"a is not int, it's {type(a)}"
+        assert a == -1
+        assert isinstance(b, float), f"b is not float, it's {type(b)}"
+        assert b == 3.14
+        assert isinstance(c, str), f"c is not str, it's {type(c)}"
+        assert c == "Hello, Flyte"
+        assert isinstance(d, bool), f"d is not bool, it's {type(d)}"
+        assert d is False
+        print("All primitive attributes passed strict type checks.")
+
+    @workflow
+    def wf(dc: DC):
+        t_dc(dc=dc)
+        t_inner(inner_dc=dc.inner_dc)
+        t_test_primitive_attributes(a=dc.a, b=dc.b, c=dc.c, d=dc.d)
+        t_test_primitive_attributes(a=dc.inner_dc.a, b=dc.inner_dc.b, c=dc.inner_dc.c, d=dc.inner_dc.d)
+
+    dc = DC()
+    wf(dc=dc)
diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py
index aa7e7dca4f..4eac1a1296 100644
--- a/tests/flytekit/unit/core/test_flyte_directory.py
+++ b/tests/flytekit/unit/core/test_flyte_directory.py
@@ -22,7 +22,10 @@
 from flytekit.models.core.types import BlobType
 from flytekit.models.literals import LiteralMap
 from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
-
+from google.protobuf import json_format as _json_format
+from google.protobuf import struct_pb2 as _struct
+from flytekit.models.literals import Literal, Scalar
+import json
 
 # Fixture that ensures a dummy local file
 @pytest.fixture
@@ -364,3 +367,13 @@ def my_wf(path: SvgDirectory) -> DC:
     dc1 = my_wf(path=svg_directory)
     dc2 = DC(f=svg_directory)
     assert dc1 == dc2
+
+def test_input_from_flyte_console_attribute_access_flytedirectory(local_dummy_directory):
+    # Flyte Console will send the input data as protobuf Struct
+
+    dict_obj = {"path": local_dummy_directory}
+    json_str = json.dumps(dict_obj)
+    upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())))
+    downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, FlyteDirectory)
+    assert isinstance(downstream_input, FlyteDirectory)
+    assert downstream_input == FlyteDirectory(local_dummy_directory)
diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py
index 7e09e918ae..121c6f8a5c 100644
--- a/tests/flytekit/unit/core/test_flyte_file.py
+++ b/tests/flytekit/unit/core/test_flyte_file.py
@@ -1,3 +1,4 @@
+import json
 import os
 import pathlib
 import tempfile
@@ -20,7 +21,9 @@
 from flytekit.models.core.types import BlobType
 from flytekit.models.literals import LiteralMap
 from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
-
+from google.protobuf import json_format as _json_format
+from google.protobuf import struct_pb2 as _struct
+from flytekit.models.literals import Literal, Scalar
 
 # Fixture that ensures a dummy local file
 @pytest.fixture
@@ -705,3 +708,12 @@ def test_new_remote_file():
     nf = FlyteFile.new_remote_file(name="foo.txt")
     assert isinstance(nf, FlyteFile)
     assert nf.path.endswith('foo.txt')
+
+def test_input_from_flyte_console_attribute_access_flytefile(local_dummy_file):
+    # Flyte Console will send the input data as protobuf Struct
+
+    dict_obj = {"path": local_dummy_file}
+    json_str = json.dumps(dict_obj)
+    upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())))
+    downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, FlyteFile)
+    assert downstream_input == FlyteFile(local_dummy_file)
diff --git a/tests/flytekit/unit/core/test_type_engine_binary_idl.py b/tests/flytekit/unit/core/test_type_engine_binary_idl.py
index 986fac0c7c..171b774360 100644
--- a/tests/flytekit/unit/core/test_type_engine_binary_idl.py
+++ b/tests/flytekit/unit/core/test_type_engine_binary_idl.py
@@ -627,7 +627,7 @@ def t_inner(inner_dc: InnerDC):
         with open(os.path.join(inner_dc.o, "file"), "r") as fh:
             assert fh.read() == "Hello FlyteDirectory"
         assert inner_dc.o.downloaded
-        print("Test InnerDC Successfully Passed")
+
         # enum: Status
         assert inner_dc.enum_status == Status.PENDING
 
@@ -690,8 +690,6 @@ def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: Li
         # Strict type check for Enum
         assert isinstance(enum_status, Status), "enum_status is not Status"
 
-        print("All attributes passed strict type checks.")
-
     @workflow
     def wf(dc: DC):
         t_inner(dc.inner_dc)
@@ -710,6 +708,8 @@ def wf(dc: DC):
     wf(dc=DC())
 
 def test_backward_compatible_with_dataclass_in_protobuf_struct(local_dummy_file, local_dummy_directory):
+    # Flyte Console will send the input data as protobuf Struct.
+    # This test also covers how the Flyte Console performs attribute access on the Struct object.
 
     @dataclass
     class InnerDC:
@@ -777,7 +777,7 @@ def t_inner(inner_dc: InnerDC):
         with open(os.path.join(inner_dc.o, "file"), "r") as fh:
             assert fh.read() == "Hello FlyteDirectory"
         assert inner_dc.o.downloaded
-        print("Test InnerDC Successfully Passed")
+
         # enum: Status
         assert inner_dc.enum_status == Status.PENDING
 
@@ -838,8 +838,6 @@ def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: Li
         # Strict type check for Enum
         assert isinstance(enum_status, Status), "enum_status is not Status"
 
-        print("All attributes passed strict type checks.")
-
     # This is the old dataclass serialization behavior.
     # https://github.com/flyteorg/flytekit/blob/94786cfd4a5c2c3b23ac29dcd6f04d0553fa1beb/flytekit/core/type_engine.py#L702-L728
     dc = DC()
diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py
index f107384b96..9487c6f4c3 100644
--- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py
+++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py
@@ -16,7 +16,6 @@
 from flytekit.core.task import task
 from flytekit.core.type_engine import TypeEngine
 from flytekit.core.workflow import workflow
-from flytekit.exceptions.user import FlyteAssertion
 from flytekit.lazy_import.lazy_module import is_imported
 from flytekit.models import literals
 from flytekit.models.literals import StructuredDatasetMetadata
@@ -49,7 +48,6 @@
 )
 
 df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
-
 def test_protocol():
     assert get_protocol("s3://my-s3-bucket/file") == "s3"
     assert get_protocol("/file") == "file"
@@ -57,8 +55,6 @@ def test_protocol():
 
 def generate_pandas() -> pd.DataFrame:
     return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})
-
-
 def test_formats_make_sense():
     @task
     def t1(a: pd.DataFrame) -> pd.DataFrame: