Pydantic Transformer V2 #2792
Signed-off-by: Future-Outlier <[email protected]>
if lv.scalar.primitive.float_value is not None:
    logger.info(f"Converting literal float {lv.scalar.primitive.float_value} to int, might have precision loss.")
    return int(lv.scalar.primitive.float_value)
This handles the case where the input comes from the Flyte console and you use attribute access directly: the float has to be converted to int.
Since JavaScript has only number, it can't tell the difference between int and float, and when Golang (propeller) does the attribute access, it doesn't have the expected Python type.
class TrainConfig(BaseModel):
    lr: float = 1e-3
    batch_size: int = 32

@workflow
def wf(cfg: TrainConfig) -> TrainConfig:
    return t_args(a=cfg.lr, batch_size=cfg.batch_size)
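To make the fallback concrete, here is a minimal stdlib-only sketch of what the branch in the diff does when an int is expected but the literal carries a float. The function name `literal_float_to_int` is hypothetical and only stands in for the transformer code; it is not the flytekit API.

```python
import logging

logger = logging.getLogger(__name__)


def literal_float_to_int(float_value: float) -> int:
    """Hypothetical stand-in for the int transformer's float fallback."""
    logger.info(f"Converting literal float {float_value} to int, might have precision loss.")
    return int(float_value)


# The console sends 32 as a float because JavaScript has a single number type.
batch_size = literal_float_to_int(32.0)

# int() truncates toward zero, which is the precision loss the log message warns about.
truncated = literal_float_to_int(32.7)
```

The second call shows why the log line matters: a non-integral float is silently truncated rather than rejected.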
The JavaScript issue and the attribute access issue are orthogonal, right?
This should only be a JavaScript problem. Attribute access should work since msgpack preserves float/int even in attribute access, correct?
Yes, attribute access works well. The problem is that JavaScript passes a float to Golang, and Golang passes the float on to Python.
> this should only be a javascript problem. attribute access should work since msgpack preserves float/int even in attribute access correct?
Yes, but when you access a simple type, you have to change the behavior of SimpleTransformer.
For the Pydantic Transformer, we pass strict=False as an argument so that the value is converted to the right type.
def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
    if binary_idl_object.tag == MESSAGEPACK:
        dict_obj = msgpack.loads(binary_idl_object.value)
        python_val = expected_python_type.model_validate(obj=dict_obj, strict=False)
        return python_val
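As a rough illustration of what lax validation buys here, the sketch below approximates the one coercion that matters in this thread (an integral float arriving for an int field) using only the standard library. It is not pydantic's implementation: `lax_load` is a hypothetical helper, a dataclass stands in for the BaseModel, and pydantic's real lax rules cover many more types.

```python
from dataclasses import dataclass, fields


@dataclass
class TrainConfig:
    lr: float = 1e-3
    batch_size: int = 32


def lax_load(cls, raw: dict):
    """Hypothetical helper: coerce integral floats to int for int-typed fields."""
    kwargs = {}
    for f in fields(cls):
        if f.name not in raw:
            continue  # fall back to the field default
        value = raw[f.name]
        if f.type is int and isinstance(value, float):
            if not value.is_integer():
                raise ValueError(f"{f.name}: {value} has a fractional part")
            value = int(value)
        kwargs[f.name] = value
    return cls(**kwargs)


# The deserialized dict carries 64.0 because the value came through the console.
cfg = lax_load(TrainConfig, {"lr": 0.01, "batch_size": 64.0})
```

Unlike the plain `int()` fallback in the int transformer, this sketch rejects values such as 64.5 instead of truncating them.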
@lukas503
Hi @Future-Outlier, Thanks for working on the Pydantic I've been testing the code locally and wondered about a behavior related to caching. Specifically, I’m curious if Here’s an example:

from flytekit import task, workflow
from pydantic import BaseModel

class Config(BaseModel):
    x: int = 1
    # y: int = 4

@task(cache=True, cache_version="v1")
def task1(val: int) -> Config:
    return Config()

@task(cache=True, cache_version="v1")
def task2(cfg: Config) -> Config:
    print("CALLED!", cfg)
    return cfg

@workflow
def my_workflow():
    config = task1(val=5)
    task2(cfg=config)

if __name__ == "__main__":
    print(Config.model_json_schema())
    my_workflow()

When I run the workflow for the first time, nothing is cached. On the second run, the results are cached, as expected. However, if I uncomment Is this the expected behavior? Shouldn't schema changes like this invalidate the cache?
Good question. I will test this out and ask other maintainers if I don't know what happened, thank you <3
@lukas503
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #2792 +/- ##
==========================================
- Coverage 76.82% 74.83% -2.00%
==========================================
Files 196 196
Lines 20301 20331 +30
Branches 2610 2618 +8
==========================================
- Hits 15596 15214 -382
- Misses 4004 4364 +360
- Partials 701 753 +52
if lv.scalar:
    if lv.scalar.binary:
        return self.from_binary_idl(lv.scalar.binary, expected_python_type)
    if lv.scalar.generic:
        return self.from_generic_idl(lv.scalar.generic, expected_python_type)
class DC(BaseModel):
    ff: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))

@task(container_image=image)
def t_args(dc: DC) -> DC:
    with open(dc.ff, "r") as f:
        print(f.read())
    return dc

@task(container_image=image)
def t_ff(ff: FlyteFile) -> FlyteFile:
    with open(ff, "r") as f:
        print(f.read())
    return ff

@workflow
def wf(dc: DC) -> DC:
    t_ff(dc.ff)
    return t_args(dc=dc)
This handles the case above when the input comes from flyteconsole.
Thanks for updating the PR. I now understand the underlying issue better. It appears the caching mechanism is ignoring the output types/schema. What’s unclear to me is why the output types/schema aren’t factored into the hash used for caching. In my opinion, any interface change, even one that only touches the outputs, could invalidate the cache. I don’t see how the old cached outputs can remain valid after an interface change. That said, this concern isn’t directly related to the current PR, so feel free to proceed as is.
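The behavior described above can be modeled with a toy cache key. This is only a sketch of the observation in this thread, assuming the key is derived from the task name, the cache version, and the input values; `cache_key` is a hypothetical function and not how Flyte actually computes its keys.

```python
import hashlib
import json


def cache_key(task_name: str, cache_version: str, inputs: dict) -> str:
    """Hypothetical key: note that the output type/schema is not part of it."""
    payload = json.dumps(
        {"task": task_name, "version": cache_version, "inputs": inputs},
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode()).hexdigest()


key_before = cache_key("task1", "v1", {"val": 5})

# Adding a field to the output model changes nothing this key can see.
key_after = cache_key("task1", "v1", {"val": 5})

# Bumping cache_version is the manual way to force a miss.
key_bumped = cache_key("task1", "v2", {"val": 5})
```

Under this model, a change that only affects the output schema yields the same key, which matches the cache hit reported above.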
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1

return issubclass(python_type, BaseModelV1) or issubclass(python_type, BaseModelV2)
This is for backward compatibility.
FlyteDirToMultipartBlobTransformer
TensorboardLogs
TFRecordsDirectory
"""

import typing

- from .types import FlyteDirectory
+ from .types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
To import FlyteDirToMultipartBlobTransformer in the pydantic plugin, we have to import it here.
def from_generic_idl(self, generic: Struct, expected_python_type: typing.Type[FlyteDirectory]) -> FlyteDirectory:
    json_str = _json_format.MessageToJson(generic)
    python_val = json.loads(json_str)
    path = python_val.get("path", None)

    if path is None:
        raise ValueError("FlyteDirectory's path should not be None")

    return FlyteDirToMultipartBlobTransformer().to_python_value(
        FlyteContextManager.current_context(),
        Literal(
            scalar=Scalar(
                blob=Blob(
                    metadata=BlobMetadata(
                        type=_core_types.BlobType(
                            format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
                        )
                    ),
                    uri=path,
                )
            )
        ),
        expected_python_type,
    )
This is literally the same as the _deserialize function for Structured Dataset.
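The first half of from_generic_idl (JSON string to dict, then a required path) can be tried in isolation without flytekit or protobuf. In the sketch below, `path_from_generic_json` is a hypothetical helper, the JSON string stands in for the output of MessageToJson, and the bucket name is made up for illustration.

```python
import json


def path_from_generic_json(json_str: str) -> str:
    """Hypothetical helper mirroring the path extraction in from_generic_idl."""
    python_val = json.loads(json_str)
    path = python_val.get("path", None)
    if path is None:
        raise ValueError("FlyteDirectory's path should not be None")
    return path


# Example payload; in the PR this string comes from the protobuf Struct.
path = path_from_generic_json('{"path": "s3://my-s3-bucket/some-dir"}')

# A payload without "path" is rejected before any Blob literal is built.
try:
    path_from_generic_json("{}")
    missing_path_rejected = False
except ValueError:
    missing_path_rejected = True
```

The extracted path is what the diff then wraps in a MULTIPART Blob literal and hands to the existing transformer.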
Tracking issue
flyteorg/flyte#5033
flyteorg/flyte#5318
demo
Screen.Recording.2024-10-08.at.11.42.53.PM.mov
How to test it by others
not sure
Why are the changes needed?
Why from_generic_idl?
To handle Flyte types in input from flyteconsole/flytectl.
Why _check_and_covert_int in the int transformer?
To handle the float issue in input from flyteconsole/flytectl.
Why BaseModel -> JSON str -> dict obj -> msgpack bytes?
For enum cases.
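The enum point can be illustrated with the standard library alone. This sketch assumes the JSON step is what turns enum members into their plain values, so that the resulting dict holds only primitives a binary serializer can pack. A dataclass stands in for the BaseModel, and the msgpack step is left out because it needs a third-party package.

```python
import json
from dataclasses import asdict, dataclass
from enum import Enum


class Color(Enum):
    RED = "red"


@dataclass
class Config:
    color: Color = Color.RED
    x: int = 1


# Dumping straight to a dict keeps the Enum member, which is not a primitive.
raw = asdict(Config())

# Going through a JSON string replaces the member with its value.
json_str = json.dumps(raw, default=lambda o: o.value)
dict_obj = json.loads(json_str)
```

After the round trip, `dict_obj` contains only strings and ints, which is the shape the final serialization step needs.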
What changes were proposed in this pull request?
Note: a pydantic BaseModel that contains a dataclass with Flyte types is not supported.
A pydantic BaseModel that contains a dataclass with primitive types is supported.
How was this patch tested?
Example code.
(nested cases, flyte types and attribute access.)