diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 1d0165751f..f5c3fb5a3c 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -29,6 +29,7 @@ from flytekit.core.node import Node from flytekit.core.promise import NodeOutput, Promise, VoidPromise from flytekit.core.resources import Resources +from flytekit.types.pickle.pickle import FlytePickleTransformer from flytekit.core.task import TaskMetadata, task from flytekit.core.testing import patch, task_mock from flytekit.core.type_engine import RestrictedTypeError, SimpleTransformer, TypeEngine, TypeTransformerFailedError @@ -115,6 +116,37 @@ def my_task(a: int): assert context_manager.FlyteContextManager.size() == 1 +def test_transformer_override(): + tf = FlytePickleTransformer() + + @task + def my_task() -> Annotated[str, tf]: + return "Hello world" + + @workflow + def wf() -> FlyteFile: + return my_task() + + annotated_output = wf() + + @task + def my_task_any() -> typing.Any: + return "Hello world" + + @workflow + def wf_any() -> FlyteFile: + return my_task_any() + + any_output = wf_any() + + with open(annotated_output, "rb") as fh: + contents_annotated = fh.read() + + with open(any_output, "rb") as fh: + contents_any = fh.read() + assert contents_annotated == contents_any + + def test_single_output(): @task def my_task() -> str: