Pydantic Transformer V2 #2792

Open
wants to merge 20 commits into master
Changes from 4 commits
13 changes: 12 additions & 1 deletion flytekit/core/type_engine.py
@@ -2052,6 +2052,17 @@ def _check_and_covert_float(lv: Literal) -> float:
raise TypeTransformerFailedError(f"Cannot convert literal {lv} to float")


def _check_and_covert_int(lv: Literal) -> int:
if lv.scalar.primitive.integer is not None:
return lv.scalar.primitive.integer

if lv.scalar.primitive.float_value is not None:
logger.info(f"Converting literal float {lv.scalar.primitive.float_value} to int, might have precision loss.")
return int(lv.scalar.primitive.float_value)
Comment on lines +2272 to +2274
@Future-Outlier (Member, Author), Oct 8, 2024

This is for cases where you provide input from the Flyte console and then use attribute access directly: you have to convert the float back to an int. Since JavaScript only has number, it can't tell the difference between int and float, and when Golang (propeller) does the attribute access it doesn't have the expected Python type (see the sketch after the example below).

from flytekit import task, workflow
from pydantic import BaseModel

class TrainConfig(BaseModel):
    lr: float = 1e-3
    batch_size: int = 32

@task  # t_args was not shown in the original comment; a plausible stand-in
def t_args(a: float, batch_size: int) -> TrainConfig:
    return TrainConfig(lr=a, batch_size=batch_size)

@workflow
def wf(cfg: TrainConfig) -> TrainConfig:
    return t_args(a=cfg.lr, batch_size=cfg.batch_size)
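
A minimal sketch of what the new helper above does with such input (assuming the private _check_and_covert_int from flytekit.core.type_engine is in scope):

from flytekit.models.literals import Literal, Primitive, Scalar

int_lv = Literal(scalar=Scalar(primitive=Primitive(integer=32)))
float_lv = Literal(scalar=Scalar(primitive=Primitive(float_value=32.0)))  # what arrives via the console path

assert _check_and_covert_int(int_lv) == 32    # plain passthrough
assert _check_and_covert_int(float_lv) == 32  # coerced to int, with an info log about possible precision loss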

Contributor

The JavaScript issue and the attribute access issue are orthogonal, right?

This should only be a JavaScript problem; attribute access should work, since msgpack preserves float/int even during attribute access, correct?

Member Author

Yes, attribute access works well. The problem is that JavaScript passes a float to Golang, and Golang then passes that float on to Python.
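
A small illustration of the root cause (assuming the console path delivers a protobuf Struct, as the transformer's to_python_value later assumes): Struct stores every number as a double, so the int-ness is already gone by the time the value reaches propeller or Python.

from google.protobuf import json_format, struct_pb2

s = struct_pb2.Struct()
json_format.Parse('{"batch_size": 32}', s)
print(s["batch_size"])  # 32.0 -- Struct number values are doubles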

Member Author

> this should only be a javascript problem. attribute access should work since msgpack preserves float/int even in attribute access correct?

Yes, but when you access a simple type, you have to change the behavior of the SimpleTransformer. For the Pydantic Transformer, we pass strict=False to model_validate so the value is converted to the right type (see the coercion sketch below the snippet):

    def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
        if binary_idl_object.tag == MESSAGEPACK:
            dict_obj = msgpack.loads(binary_idl_object.value)
            python_val = expected_python_type.model_validate(obj=dict_obj, strict=False)
            return python_val
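
A minimal sketch of the coercion this enables, assuming a msgpack payload in which batch_size arrived as a float:

import msgpack
from pydantic import BaseModel

class TrainConfig(BaseModel):
    lr: float = 1e-3
    batch_size: int = 32

payload = msgpack.dumps({"lr": 0.001, "batch_size": 32.0})  # int became a float upstream
cfg = TrainConfig.model_validate(obj=msgpack.loads(payload), strict=False)
assert cfg.batch_size == 32 and isinstance(cfg.batch_size, int)  # lax mode coerces 32.0 -> 32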


raise TypeTransformerFailedError(f"Cannot convert literal {lv} to int")


def _check_and_convert_void(lv: Literal) -> None:
if lv.scalar.none_type is None:
raise TypeTransformerFailedError(f"Cannot convert literal {lv} to None")
@@ -2065,7 +2076,7 @@ def _register_default_type_transformers():
int,
_type_models.LiteralType(simple=_type_models.SimpleType.INTEGER),
lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x))),
lambda x: x.scalar.primitive.integer,
_check_and_covert_int,
)
)

5 changes: 4 additions & 1 deletion flytekit/interaction/click_types.py
@@ -36,7 +36,10 @@ def is_pydantic_basemodel(python_type: typing.Type) -> bool:
return False
else:
try:
from pydantic.v1 import BaseModel
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1

return issubclass(python_type, BaseModelV1) or issubclass(python_type, BaseModelV2)
Comment on lines +39 to +42
Member Author

For backward compatibility.
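
A sketch of the intended behavior, assuming both model flavors should be recognized when resolving CLI input types:

from flytekit.interaction.click_types import is_pydantic_basemodel
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1

class OldStyle(BaseModelV1):
    x: int = 1

class NewStyle(BaseModelV2):
    x: int = 1

assert is_pydantic_basemodel(OldStyle)  # pydantic.v1 models keep working
assert is_pydantic_basemodel(NewStyle)  # native v2 models are now accepted too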

except ImportError:
from pydantic import BaseModel

3 changes: 2 additions & 1 deletion flytekit/types/directory/__init__.py
@@ -10,13 +10,14 @@
:template: file_types.rst

FlyteDirectory
FlyteDirToMultipartBlobTransformer
TensorboardLogs
TFRecordsDirectory
"""

import typing

from .types import FlyteDirectory
from .types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
Comment on lines +13 to +20
Member Author

To import FlyteDirToMultipartBlobTransformer in the Pydantic plugin, we have to import (re-export) it here.


# The following section provides some predefined aliases for commonly used FlyteDirectory formats.

3 changes: 2 additions & 1 deletion flytekit/types/file/__init__.py
@@ -10,6 +10,7 @@
:template: file_types.rst

FlyteFile
FlyteFilePathTransformer
HDF5EncodedFile
HTMLPage
JoblibSerializedFile
@@ -25,7 +26,7 @@

from typing_extensions import Annotated, get_args, get_origin

from .file import FlyteFile
from .file import FlyteFile, FlyteFilePathTransformer


class FileExt:
1 change: 1 addition & 0 deletions flytekit/types/schema/__init__.py
@@ -1,5 +1,6 @@
from .types import (
FlyteSchema,
FlyteSchemaTransformer,
LocalIOSchemaReader,
LocalIOSchemaWriter,
SchemaEngine,
11 changes: 8 additions & 3 deletions flytekit/types/structured/__init__.py
@@ -7,9 +7,12 @@
:template: custom.rst
:toctree: generated/

StructuredDataset
StructuredDatasetEncoder
StructuredDatasetDecoder
StructuredDataset
StructuredDatasetDecoder
StructuredDatasetEncoder
StructuredDatasetMetadata
StructuredDatasetTransformerEngine
StructuredDatasetType
"""

from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer
@@ -19,7 +22,9 @@
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
StructuredDatasetMetadata,
StructuredDatasetTransformerEngine,
StructuredDatasetType,
)


1 change: 1 addition & 0 deletions plugins/flytekit-pydantic-v2/README.md
@@ -0,0 +1 @@
# TMP
Empty file.
@@ -0,0 +1,25 @@
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured import StructuredDataset

from . import transformer
from .custom import (
deserialize_flyte_dir,
deserialize_flyte_file,
deserialize_flyte_schema,
deserialize_structured_dataset,
serialize_flyte_dir,
serialize_flyte_file,
serialize_flyte_schema,
serialize_structured_dataset,
)

setattr(FlyteFile, "serialize_flyte_file", serialize_flyte_file)
setattr(FlyteFile, "deserialize_flyte_file", deserialize_flyte_file)
setattr(FlyteDirectory, "serialize_flyte_dir", serialize_flyte_dir)
setattr(FlyteDirectory, "deserialize_flyte_dir", deserialize_flyte_dir)
setattr(FlyteSchema, "serialize_flyte_schema", serialize_flyte_schema)
setattr(FlyteSchema, "deserialize_flyte_schema", deserialize_flyte_schema)
setattr(StructuredDataset, "serialize_structured_dataset", serialize_structured_dataset)
setattr(StructuredDataset, "deserialize_structured_dataset", deserialize_structured_dataset)
130 changes: 130 additions & 0 deletions plugins/flytekit-pydantic-v2/flytekitplugins/pydantic/v2/custom.py
@@ -0,0 +1,130 @@
from typing import Dict

from flytekit.core.context_manager import FlyteContextManager
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar, Schema
from flytekit.types.directory import FlyteDirectory, FlyteDirToMultipartBlobTransformer
from flytekit.types.file import FlyteFile, FlyteFilePathTransformer
from flytekit.types.schema import FlyteSchema, FlyteSchemaTransformer
from flytekit.types.structured import (
StructuredDataset,
StructuredDatasetMetadata,
StructuredDatasetTransformerEngine,
StructuredDatasetType,
)
from pydantic import model_serializer, model_validator


@model_serializer
def serialize_flyte_file(self) -> Dict[str, str]:
lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
return {"path": lv.scalar.blob.uri}


@model_validator(mode="after")
def deserialize_flyte_file(self) -> FlyteFile:
pv = FlyteFilePathTransformer().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
blob=Blob(
metadata=BlobMetadata(
type=_core_types.BlobType(
format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
),
uri=self.path,
)
)
),
type(self),
)
pv._remote_path = None
return pv


@model_serializer
def serialize_flyte_dir(self) -> Dict[str, str]:
lv = FlyteDirToMultipartBlobTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
return {"path": lv.scalar.blob.uri}


@model_validator(mode="after")
def deserialize_flyte_dir(self) -> FlyteDirectory:
pv = FlyteDirToMultipartBlobTransformer().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
blob=Blob(
metadata=BlobMetadata(
type=_core_types.BlobType(
format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
)
),
uri=self.path,
)
)
),
type(self),
)
pv._remote_directory = None
return pv


@model_serializer
def serialize_flyte_schema(self) -> Dict[str, str]:
FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
return {"remote_path": self.remote_path}


@model_validator(mode="after")
def deserialize_flyte_schema(self) -> FlyteSchema:
# If we call to_python_value, FlyteSchemaTransformer will overwrite the local_path
# and we will lose our data.
# If this data comes from an existing FlyteSchema, the local path will be None.

if hasattr(self, "_local_path"):
return self

t = FlyteSchemaTransformer()
return t.to_python_value(
FlyteContextManager.current_context(),
Literal(scalar=Scalar(schema=Schema(self.remote_path, t._get_schema_type(type(self))))),
type(self),
)


@model_serializer
def serialize_structured_dataset(self) -> Dict[str, str]:
lv = StructuredDatasetTransformerEngine().to_literal(FlyteContextManager.current_context(), self, type(self), None)
sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri)
sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
return {
"uri": sd.uri,
"file_format": sd.file_format,
}


@model_validator(mode="after")
def deserialize_structured_dataset(self) -> StructuredDataset:
# If we call to_python_value, StructuredDatasetTransformerEngine will overwrite the 'dataframe' attribute
# and we will lose our data.
# If this data comes from an existing StructuredDataset, the dataframe will be None.

if hasattr(self, "dataframe"):
return self

return StructuredDatasetTransformerEngine().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
structured_dataset=StructuredDataset(
metadata=StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(format=self.file_format)
),
uri=self.uri,
)
)
),
type(self),
)
@@ -0,0 +1,73 @@
import json
from typing import Type

import msgpack
from google.protobuf import json_format as _json_format

from flytekit import FlyteContext
from flytekit.core.constants import MESSAGEPACK
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.loggers import logger
from flytekit.models import types
from flytekit.models.literals import Binary, Literal, Scalar
from flytekit.models.types import LiteralType, TypeStructure
from pydantic import BaseModel


class PydanticTransformer(TypeTransformer[BaseModel]):
def __init__(self):
super().__init__("Pydantic Transformer", BaseModel, enable_type_assertions=False)

def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
schema = t.model_json_schema()
literal_type = {}
fields = t.__annotations__.items()

for name, python_type in fields:
try:
literal_type[name] = TypeEngine.to_literal_type(python_type)
except Exception as e:
logger.warning(
"Field {} of type {} cannot be converted to a literal type. Error: {}".format(name, python_type, e)
)

ts = TypeStructure(tag="", dataclass_type=literal_type)

return types.LiteralType(simple=types.SimpleType.STRUCT, metadata=schema, structure=ts)

def to_literal(
self,
ctx: FlyteContext,
python_val: BaseModel,
python_type: Type[BaseModel],
expected: types.LiteralType,
) -> Literal:
dict_obj = python_val.model_dump()
msgpack_bytes = msgpack.dumps(dict_obj)
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))

def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
if binary_idl_object.tag == MESSAGEPACK:
dict_obj = msgpack.loads(binary_idl_object.value)
python_val = expected_python_type.model_validate(obj=dict_obj, strict=False)
return python_val
else:
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[BaseModel]) -> BaseModel:
"""
There are two kinds of literal values:
1. protobuf Struct (from the Flyte console)
2. binary scalar (everything else)
Hence we have to handle both cases.
"""
if lv and lv.scalar and lv.scalar.binary is not None:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

json_str = _json_format.MessageToJson(lv.scalar.generic)
dict_obj = json.loads(json_str)
python_val = expected_python_type.model_validate(obj=dict_obj, strict=False)
return python_val


TypeEngine.register(PydanticTransformer())
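
A round-trip sketch of the transformer above (illustrative only, not part of the diff; assumes PydanticTransformer is importable from this module):

from flytekit.core.context_manager import FlyteContextManager
from pydantic import BaseModel

class Config(BaseModel):
    name: str = "demo"
    epochs: int = 5

ctx = FlyteContextManager.current_context()
xf = PydanticTransformer()
lt = xf.get_literal_type(Config)
lv = xf.to_literal(ctx, Config(epochs=10), Config, lt)  # msgpack-tagged Binary scalar
restored = xf.to_python_value(ctx, lv, Config)
assert restored.epochs == 10 and isinstance(restored, Config)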
40 changes: 40 additions & 0 deletions plugins/flytekit-pydantic-v2/setup.py
@@ -0,0 +1,40 @@
from setuptools import setup

PLUGIN_NAME = "pydantic"

microlib_name = f"flytekitplugins-{PLUGIN_NAME}-v2"

plugin_requires = ["flytekit>1.13.7", "pydantic>=2.9.2"]

__version__ = "0.0.0+develop"

setup(
name=microlib_name,
version=__version__,
author="flyteorg",
author_email="[email protected]",
description="Plugin adding type support for Pydantic models",
url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekitplugins-pydantic-v2",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
namespace_packages=["flytekitplugins"],
packages=[f"flytekitplugins.{PLUGIN_NAME}.v2"],
install_requires=plugin_requires,
license="apache2",
python_requires=">=3.9",
classifiers=[
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}.v2=flytekitplugins.{PLUGIN_NAME}.v2"]},
)