Pydantic Transformer V2 #2792

Open · wants to merge 20 commits into master
Conversation

Conversation

@Future-Outlier (Member) commented Oct 8, 2024

Tracking issue

flyteorg/flyte#5033
flyteorg/flyte#5318

demo

Screen.Recording.2024-10-08.at.11.42.53.PM.mov

How others can test it

  1. git clone https://github.com/flyteorg/flytekit
  2. gh pr checkout 2792
  3. make setup-global-uv
  4. cd plugins/flytekit-pydantic-v2 && pip install -e .
  5. test a workflow example

Open questions

  1. the file tree structure.

Why are the changes needed?

  1. Why from_generic_idl?
     flyteconsole/flytectl send inputs as a generic protobuf Struct, so Flyte types need a conversion path from the generic IDL representation.

  2. Why _check_and_covert_int in the int transformer?
     flyteconsole/flytectl inputs arrive as floats (JSON numbers), so the int transformer has to convert them back to ints.

  3. Why basemodel -> json str -> dict obj -> msgpack bytes?
     For enum cases: model_dump_json serializes enum fields to their underlying values, whereas model_dump leaves Enum objects that msgpack cannot encode.
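The basemodel → json str → dict obj → msgpack bytes path in point 3 can be sketched as follows (a minimal illustration; `Status` and `DC` are stand-ins, not the plugin's actual code):

```python
import json
from enum import Enum

import msgpack  # third-party: pip install msgpack
from pydantic import BaseModel


class Status(Enum):
    PENDING = "pending"


class DC(BaseModel):
    a: int = -1
    enum_status: Status = Status.PENDING


dc = DC()
# model_dump_json serializes the enum field to its value ("pending"),
# while plain model_dump would leave an Enum object msgpack can't encode.
json_str = dc.model_dump_json()
dict_obj = json.loads(json_str)
msgpack_bytes = msgpack.dumps(dict_obj)
```

Round-tripping the bytes yields a plain dict with the enum already flattened to its value, which is what the transformer stores in the literal.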

What changes were proposed in this pull request?

  • attribute access (primitives and Flyte types; datetime support is uncertain)
  • flyte types
  • nested cases
  • dataclasses.dataclass in pydantic.BaseModel
  • pydantic.dataclass in pydantic.BaseModel
  • pydantic.BaseModel in pydantic.BaseModel

Note: we don't support a pydantic BaseModel containing a dataclass with FlyteTypes.
We do support a pydantic BaseModel containing a dataclass with primitive types.
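A hedged illustration of that support matrix (class names here are made up for the example): a BaseModel holding a dataclass of primitives validates fine, while the same dataclass holding a FlyteFile is the unsupported case.

```python
from dataclasses import dataclass

from pydantic import BaseModel, Field


@dataclass
class PrimitiveInner:
    # Supported: only primitive types inside the dataclass.
    a: int = -1
    c: str = "Hello, Flyte"


class Outer(BaseModel):
    # A dataclass field on a BaseModel; pydantic v2 validates it natively.
    inner: PrimitiveInner = Field(default_factory=PrimitiveInner)


o = Outer()
```

Swapping a FlyteFile field into `PrimitiveInner` would be the combination this PR explicitly does not support.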

How was this patch tested?

Example code.
(nested cases, flyte types and attribute access.)

from pydantic import BaseModel, Field
import os
from typing import Dict, List
from flytekit.types.file import FlyteFile
from flytekit.types.directory import FlyteDirectory
from flytekit import task, workflow, ImageSpec
from enum import Enum


flytekit_hash = "afd43443acd541e8336db5335de364b5d23cfff6"
flytekit = f"git+https://github.com/flyteorg/flytekit.git@{flytekit_hash}"
pydantic_plugin = f"git+https://github.com/flyteorg/flytekit.git@{flytekit_hash}#subdirectory=plugins/flytekit-pydantic-v2"

# Define custom image for the task
image = ImageSpec(packages=[flytekit, pydantic_plugin],
                            apt_packages=["git"],
                            registry="localhost:30000",
                         )

class Status(Enum):
    PENDING = "pending"
    APPROVED = "approved"
    REJECTED = "rejected"

class InnerDC(BaseModel):
    a: int = -1
    b: float = 2.1
    c: str = "Hello, Flyte"
    d: bool = False
    e: List[int] = Field(default_factory=lambda: [0, 1, 2, -1, -2])
    f: List[FlyteFile] = Field(default_factory=lambda: [FlyteFile("s3://my-s3-bucket/example.txt")])
    g: List[List[int]] = Field(default_factory=lambda: [[0], [1], [-1]])
    h: List[Dict[int, bool]] = Field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}])
    i: Dict[int, bool] = Field(default_factory=lambda: {0: False, 1: True, -1: False})
    j: Dict[int, FlyteFile] = Field(default_factory=lambda: {0: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                             1: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                             -1: FlyteFile("s3://my-s3-bucket/example.txt")})
    k: Dict[int, List[int]] = Field(default_factory=lambda: {0: [0, 1, -1]})
    l: Dict[int, Dict[int, int]] = Field(default_factory=lambda: {1: {-1: 0}})
    m: dict = Field(default_factory=lambda: {"key": "value"})
    n: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))
    o: FlyteDirectory = Field(default_factory=lambda: FlyteDirectory("s3://my-s3-bucket/s3_flyte_dir"))
    enum_status: Status = Status.PENDING


class DC(BaseModel):
    a: int = -1
    b: float = 2.1
    c: str = "Hello, Flyte"
    d: bool = False
    e: List[int] = Field(default_factory=lambda: [0, 1, 2, -1, -2])
    f: List[FlyteFile] = Field(default_factory=lambda: [FlyteFile("s3://my-s3-bucket/example.txt")])
    g: List[List[int]] = Field(default_factory=lambda: [[0], [1], [-1]])
    h: List[Dict[int, bool]] = Field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}])
    i: Dict[int, bool] = Field(default_factory=lambda: {0: False, 1: True, -1: False})
    j: Dict[int, FlyteFile] = Field(default_factory=lambda: {0: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                             1: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                             -1: FlyteFile("s3://my-s3-bucket/example.txt")})
    k: Dict[int, List[int]] = Field(default_factory=lambda: {0: [0, 1, -1]})
    l: Dict[int, Dict[int, int]] = Field(default_factory=lambda: {1: {-1: 0}})
    m: dict = Field(default_factory=lambda: {"key": "value"})
    n: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))
    o: FlyteDirectory = Field(default_factory=lambda: FlyteDirectory("s3://my-s3-bucket/s3_flyte_dir"))
    inner_dc: InnerDC = Field(default_factory=lambda: InnerDC())
    enum_status: Status = Status.PENDING


@task(container_image=image)
def t_dc(dc: DC) -> DC:
    return dc


@task(container_image=image)
def t_inner(inner_dc: InnerDC):
    assert isinstance(inner_dc, InnerDC)

    expected_file_content = "Default content"

    # f: List[FlyteFile]
    for ff in inner_dc.f:
        assert isinstance(ff, FlyteFile)
        with open(ff, "r") as f:
            assert f.read() == expected_file_content
    # j: Dict[int, FlyteFile]
    for _, ff in inner_dc.j.items():
        assert isinstance(ff, FlyteFile)
        with open(ff, "r") as f:
            assert f.read() == expected_file_content
    # n: FlyteFile
    assert isinstance(inner_dc.n, FlyteFile)
    with open(inner_dc.n, "r") as f:
        assert f.read() == expected_file_content
    # o: FlyteDirectory
    assert isinstance(inner_dc.o, FlyteDirectory)
    assert not inner_dc.o.downloaded
    with open(os.path.join(inner_dc.o, "example.txt"), "r") as fh:
        assert fh.read() == expected_file_content
    assert inner_dc.o.downloaded
    print("Test InnerDC Successfully Passed")
    # enum: Status
    assert inner_dc.enum_status == Status.PENDING


@task(container_image=image)
def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: List[FlyteFile], g: List[List[int]],
                          h: List[Dict[int, bool]], i: Dict[int, bool], j: Dict[int, FlyteFile],
                          k: Dict[int, List[int]], l: Dict[int, Dict[int, int]], m: dict,
                          n: FlyteFile, o: FlyteDirectory,
                          enum_status: Status
                          ):
    # Strict type checks for simple types
    assert isinstance(a, int), f"a is not int, it's {type(a)}"
    assert a == -1
    assert isinstance(b, float), f"b is not float, it's {type(b)}"
    assert isinstance(c, str), f"c is not str, it's {type(c)}"
    assert isinstance(d, bool), f"d is not bool, it's {type(d)}"

    # Strict type checks for List[int]
    assert isinstance(e, list) and all(isinstance(i, int) for i in e), "e is not List[int]"

    # Strict type checks for List[FlyteFile]
    assert isinstance(f, list) and all(isinstance(i, FlyteFile) for i in f), "f is not List[FlyteFile]"

    # Strict type checks for List[List[int]]
    assert isinstance(g, list) and all(
        isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]"

    # Strict type checks for List[Dict[int, bool]]
    assert isinstance(h, list) and all(
        isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h
    ), "h is not List[Dict[int, bool]]"

    # Strict type checks for Dict[int, bool]
    assert isinstance(i, dict) and all(
        isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]"

    # Strict type checks for Dict[int, FlyteFile]
    assert isinstance(j, dict) and all(
        isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not Dict[int, FlyteFile]"

    # Strict type checks for Dict[int, List[int]]
    assert isinstance(k, dict) and all(
        isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in
        k.items()), "k is not Dict[int, List[int]]"

    # Strict type checks for Dict[int, Dict[int, int]]
    assert isinstance(l, dict) and all(
        isinstance(k, int) and isinstance(v, dict) and all(
            isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items())
        for k, v in l.items()), "l is not Dict[int, Dict[int, int]]"

    # Strict type check for a generic dict
    assert isinstance(m, dict), "m is not dict"

    # Strict type check for FlyteFile
    assert isinstance(n, FlyteFile), "n is not FlyteFile"

    # Strict type check for FlyteDirectory
    assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory"

    # Strict type check for Enum
    assert isinstance(enum_status, Status), "enum_status is not Status"

    print("All attributes passed strict type checks.")


@workflow
def wf(dc: DC):
    t_dc(dc=dc)
    t_inner(inner_dc=dc.inner_dc)
    t_test_all_attributes(a=dc.a, b=dc.b, c=dc.c,
                          d=dc.d, e=dc.e, f=dc.f,
                          g=dc.g, h=dc.h, i=dc.i,
                          j=dc.j, k=dc.k, l=dc.l,
                          m=dc.m, n=dc.n, o=dc.o,
                          enum_status=dc.enum_status
                          )

    t_test_all_attributes(a=dc.inner_dc.a, b=dc.inner_dc.b, c=dc.inner_dc.c,
                          d=dc.inner_dc.d, e=dc.inner_dc.e, f=dc.inner_dc.f,
                          g=dc.inner_dc.g, h=dc.inner_dc.h, i=dc.inner_dc.i,
                          j=dc.inner_dc.j, k=dc.inner_dc.k, l=dc.inner_dc.l,
                          m=dc.inner_dc.m, n=dc.inner_dc.n, o=dc.inner_dc.o,
                          enum_status=dc.inner_dc.enum_status
                          )

if __name__ == "__main__":
    from flytekit.clis.sdk_in_container import pyflyte
    from click.testing import CliRunner
    import os

    runner = CliRunner()
    path = os.path.realpath(__file__)
    input_val = '{"a": -1, "b": 3.14}'
    result = runner.invoke(pyflyte.main,
                           ["run", path, "wf", "--dc", input_val])
    print("Local Execution: ", result.output)
    #
    result = runner.invoke(pyflyte.main,
                           ["run", "--remote", path, "wf", "--dc", input_val])
    print("Remote Execution: ", result.output)

Setup process

Screenshots

Check all the applicable boxes

  • I updated the documentation accordingly.
  • All new and existing tests passed.
  • All commits are signed-off.

Related PRs

Docs link

Signed-off-by: Future-Outlier <[email protected]>
@Future-Outlier Future-Outlier changed the title Pydantic Transformer V2 [wip] Pydantic Transformer V2 Oct 8, 2024
Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Comment on lines +2059 to +2061
if lv.scalar.primitive.float_value is not None:
logger.info(f"Converting literal float {lv.scalar.primitive.float_value} to int, might have precision loss.")
return int(lv.scalar.primitive.float_value)
@Future-Outlier (Member, Author) commented Oct 8, 2024
This is for the case where the input comes from the Flyte console and you use attribute access directly: the float has to be converted back to an int. JavaScript has only a single number type, so it can't tell int and float apart, and when Golang (propeller) performs the attribute access, it doesn't know the expected Python type.

class TrainConfig(BaseModel):
    lr: float = 1e-3
    batch_size: int = 32

@task
def t_args(lr: float, batch_size: int) -> TrainConfig:
    return TrainConfig(lr=lr, batch_size=batch_size)

@workflow
def wf(cfg: TrainConfig) -> TrainConfig:
    return t_args(lr=cfg.lr, batch_size=cfg.batch_size)
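The conversion described above can be sketched in isolation (the helper name and structure are illustrative, not the actual flytekit implementation):

```python
def check_and_convert_int(value):
    """Illustrative sketch: coerce a console-provided float back to int.

    JavaScript's single number type means an int like 32 can arrive as 32.0;
    mirroring the diff above, the conversion may lose precision for
    non-integral floats.
    """
    if isinstance(value, float):
        return int(value)
    return value
```

Any value that is already an int (or any other type) passes through untouched; only floats are truncated to int.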

A contributor commented:

The JavaScript issue and the attribute-access issue are orthogonal, right?

This should only be a JavaScript problem; attribute access should work, since msgpack preserves the float/int distinction even during attribute access. Correct?

@Future-Outlier (Member, Author) commented:

Yes, attribute access works well. The issue is that JavaScript passes a float to Golang, and Golang then passes that float to Python.

@Future-Outlier (Member, Author) commented:

> this should only be a javascript problem. attribute access should work since msgpack preserves float/int even in attribute access correct?

Yes, but when you are accessing a simple type, you would have to change the behavior of SimpleTransformer.

For the Pydantic transformer, we instead pass strict=False to model_validate so the value is coerced to the right type:

    def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
        if binary_idl_object.tag == MESSAGEPACK:
            dict_obj = msgpack.loads(binary_idl_object.value)
            python_val = expected_python_type.model_validate(obj=dict_obj, strict=False)
            return python_val

@Future-Outlier (Member, Author) commented:
@lukas503
Hi, I saw you added an emoji to this PR!
Do you want to help test it?
Searching for "How to test it by others" above will give you a guide!

@lukas503 commented Oct 8, 2024

Hi @Future-Outlier,

Thanks for working on the Pydantic TypeTransformer! Which "How to test it by others?" guide are you referring to?

I've been testing the code locally and wondered about a behavior related to caching. Specifically, I’m curious if model_json_schema is considered in the hash used for caching.

Here’s an example:

from flytekit import task, workflow
from pydantic import BaseModel

class Config(BaseModel):
    x: int = 1
    # y: int = 4

@task(cache=True, cache_version="v1")
def task1(val: int) -> Config:
    return Config()

@task(cache=True, cache_version="v1")
def task2(cfg: Config) -> Config:
    print("CALLED!", cfg)
    return cfg

@workflow
def my_workflow():
    config = task1(val=5)
    task2(cfg=config)

if __name__ == "__main__":
    print(Config.model_json_schema())
    my_workflow()

When I run the workflow for the first time, nothing is cached. On the second run, the results are cached, as expected. However, if I uncomment y: int = 4, the tasks still remain cached. I would expect that schema change to bust the cache and re-execute the tasks. This causes a failure if I update the attributes and bump only the cache_version of task2 (task1 still returns its stale cached output).

Is this the expected behavior? Shouldn't schema changes like this invalidate the cache?
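One workaround sketch for the behavior described above (purely illustrative, not part of flytekit): fold the model's JSON schema into cache_version, so that any schema change busts the cache automatically.

```python
import hashlib
import json

from pydantic import BaseModel


class Config(BaseModel):
    x: int = 1


def schema_version(model_cls: type, base: str = "v1") -> str:
    """Derive a cache_version string that changes whenever the schema does."""
    schema = json.dumps(model_cls.model_json_schema(), sort_keys=True)
    digest = hashlib.sha256(schema.encode()).hexdigest()[:8]
    return f"{base}-{digest}"


# Hypothetical usage: @task(cache=True, cache_version=schema_version(Config))
version = schema_version(Config)
```

Adding or removing a field changes `model_json_schema()`, hence the digest, hence the effective cache key; this is only a user-side sketch, not a claim about how flytekit computes cache hashes.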

@Future-Outlier (Member, Author) commented:

> (quoting @lukas503's caching question above)

good question, will test this out and ask other maintainers if I don't know what happened, thank you <3

@Future-Outlier (Member, Author) commented:

> (quoting @lukas503's caching question above)

@lukas503
Sorry, can you try again? I've updated the description above.

codecov bot commented Oct 9, 2024

Codecov Report

Attention: Patch coverage is 33.33333% with 50 lines in your changes missing coverage. Please review.

Project coverage is 74.83%. Comparing base (cd8216a) to head (773a3b6).
Report is 1 commit behind head on master.

Files with missing lines                        | Patch % | Lines
flytekit/types/schema/types.py                  | 25.00%  | 9 Missing and 3 partials ⚠️
flytekit/types/structured/structured_dataset.py | 25.00%  | 9 Missing and 3 partials ⚠️
flytekit/types/directory/types.py               | 40.00%  | 7 Missing and 2 partials ⚠️
flytekit/types/file/file.py                     | 40.00%  | 7 Missing and 2 partials ⚠️
flytekit/core/type_engine.py                    | 37.50%  | 4 Missing and 1 partial ⚠️
flytekit/interaction/click_types.py             | 0.00%   | 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2792      +/-   ##
==========================================
- Coverage   76.82%   74.83%   -2.00%     
==========================================
  Files         196      196              
  Lines       20301    20331      +30     
  Branches     2610     2618       +8     
==========================================
- Hits        15596    15214     -382     
- Misses       4004     4364     +360     
- Partials      701      753      +52     


Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Comment on lines +586 to +590
if lv.scalar:
if lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type)
if lv.scalar.generic:
return self.from_generic_idl(lv.scalar.generic, expected_python_type)
@Future-Outlier (Member, Author) commented:

class DC(BaseModel):
    ff: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))

@task(container_image=image)
def t_args(dc: DC) -> DC:
    with open(dc.ff, "r") as f:
        print(f.read())
    return dc

@task(container_image=image)
def t_ff(ff: FlyteFile) -> FlyteFile:
    with open(ff, "r") as f:
        print(f.read())
    return ff

@workflow
def wf(dc: DC) -> DC:
    t_ff(ff=dc.ff)
    return t_args(dc=dc)

This handles the case where the input comes from flyteconsole (which sends it as a generic Struct).

@lukas503 commented Oct 9, 2024

> sorry can you try again?
> I've updated the above description.

Thanks for updating the PR. I now understand the underlying issue better: the caching mechanism ignores the output types/schema. What's unclear to me is why the output types/schema aren't factored into the hash used for caching. In my opinion, any interface change should invalidate the cache, including changes to the outputs; I don't see how old cached outputs can remain valid after an interface change.

That said, this concern isn’t directly related to the current PR, so feel free to proceed as is.

Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
@Future-Outlier Future-Outlier changed the title [wip] Pydantic Transformer V2 Pydantic Transformer V2 Oct 10, 2024
Comment on lines +39 to +42
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1

return issubclass(python_type, BaseModelV1) or issubclass(python_type, BaseModelV2)
@Future-Outlier (Member, Author) commented:

For backward compatibility.
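A sketch of how that check behaves (`is_pydantic_model` is an illustrative wrapper, not the plugin's actual function name); pydantic v2 ships the v1 API under `pydantic.v1`, so both model styles can coexist:

```python
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1


def is_pydantic_model(python_type) -> bool:
    """True for both pydantic v1-style and v2-style model classes."""
    try:
        return issubclass(python_type, (BaseModelV1, BaseModelV2))
    except TypeError:
        # Not a class at all (e.g. a typing construct).
        return False


class ModelV2(BaseModelV2):
    x: int = 0
```

The try/except guards against non-class arguments, for which issubclass raises TypeError.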

Comment on lines +13 to +20
FlyteDirToMultipartBlobTransformer
TensorboardLogs
TFRecordsDirectory
"""

import typing

from .types import FlyteDirectory
from .types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
@Future-Outlier (Member, Author) commented:

To import FlyteDirToMultipartBlobTransformer in the pydantic plugin, we have to re-export it here.

Comment on lines +541 to +564
def from_generic_idl(self, generic: Struct, expected_python_type: typing.Type[FlyteDirectory]) -> FlyteDirectory:
json_str = _json_format.MessageToJson(generic)
python_val = json.loads(json_str)
path = python_val.get("path", None)

if path is None:
raise ValueError("FlyteDirectory's path should not be None")

return FlyteDirToMultipartBlobTransformer().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
blob=Blob(
metadata=BlobMetadata(
type=_core_types.BlobType(
format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
)
),
uri=path,
)
)
),
expected_python_type,
)
@Future-Outlier (Member, Author) commented Oct 11, 2024

This is essentially the same as the _deserialize function for StructuredDataset.
