
[BUG] Dynamic workflows break caching of artifacts offloaded to blob storage #3842

Open
fg91 opened this issue Jul 6, 2023 · 7 comments
Labels: backlogged, bug, flytekit, stale


fg91 commented Jul 6, 2023

Describe the bug

Some type transformers offload artifacts to blob storage as follows (e.g. here):

def to_literal(...) -> Literal:
    local_path = ...
    # Save the object to the local path

    remote_path = ctx.file_access.get_random_remote_path(local_path)
    ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
    # Return a Literal containing remote_path

When objects of such types are passed to @dynamic workflows and then passed along to tasks called within the dynamic workflow, this behaviour always leads to cache misses. The reason is that in the dynamic workflow, the objects are deserialized and then serialized again to a different random remote path.
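To illustrate the mechanism (a minimal sketch with a hypothetical helper, not Flyte's actual data-catalog code): the cache key is derived from a task's input literals, so an input whose blob literal differs only in its random remote URI already produces a different key.

import hashlib


def cache_key(task_id: str, cache_version: str, inputs: dict) -> str:
    # Hypothetical stand-in for the catalog's key computation: hash the task id,
    # the cache version and the (sorted) input literal values.
    payload = f"{task_id}:{cache_version}:{sorted(inputs.items())}"
    return hashlib.sha256(payload.encode()).hexdigest()


# Same model bytes, but re-uploaded by the dynamic workflow to a new random
# remote path -> different input literal -> different key -> cache miss.
assert cache_key("train", "0.1", {"model": "s3://bucket/ab12cd/model.pt"}) != \
       cache_key("train", "0.1", {"model": "s3://bucket/ef34gh/model.pt"})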

Expected behavior

There should not be cache misses in this situation.

Additional context to reproduce

Let us consider this example workflow:

import torch.nn as nn
from flytekit import task, workflow, dynamic


@task(cache=True, cache_version="0.1")
def train(model: nn.Module) -> nn.Module:
    print(f"Training model {model}")
    return model


@task(cache=True, cache_version="0.1")
def other_task(param: int) -> int:
    print(f"Doing something else with param {param}")
    return param


@dynamic(cache=True, cache_version="0.1")
def sub_wf(model: nn.Module, param: int) -> tuple[nn.Module, int]:
    other_task(param=param)
    train(model=model)

    return model, param


@task(cache=True, cache_version="0.1")
def create_model() -> nn.Module:
    return nn.Linear(1, 1)


@workflow
def wf(param: int = 1):
    model = create_model()
    sub_wf(model=model, param=param)

Screenshots

The first execution with param=1 results in cache puts for all tasks:

Screenshot 2023-07-06 at 15 57 48

Next, let's re-run the workflow but with param=2:

Screenshot 2023-07-06 at 16 09 30

A cache miss for other_task is expected since we changed param. However, since create_model had a cache hit, train should have had one as well.

Instead, one can observe that the output of create_model (retrieved from cache) ...

Screenshot 2023-07-06 at 16 09 44

... is not the same as the input to train:

Screenshot 2023-07-06 at 16 09 55

The reason is that the respective type transformer deserialized the artifact and serialized it again to a new random bucket path.

I don't know exactly how this could be solved, but the type engine should "somehow detect that it is being executed in a dynamic workflow" and "simply forward the artifacts instead of serializing them to a different location".
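Purely as a sketch of what "forwarding" could mean (hypothetical; none of this exists in flytekit today): a value could carry the remote URI it was deserialized from, and the transformer could reuse that URI instead of uploading to a fresh random path.

from dataclasses import dataclass
from typing import Optional
import uuid


@dataclass
class OffloadedValue:
    payload: bytes
    source_uri: Optional[str] = None  # would be set when the value is downloaded


def to_remote(value: OffloadedValue) -> str:
    # Forward the artifact if it already lives in blob storage; otherwise
    # upload it to a fresh random path (as transformers do today).
    if value.source_uri is not None:
        return value.source_uri
    return f"s3://bucket/{uuid.uuid4().hex}/artifact.bin"


passed_through = OffloadedValue(b"weights", source_uri="s3://bucket/ab12cd/artifact.bin")
assert to_remote(passed_through) == "s3://bucket/ab12cd/artifact.bin"  # stable URI -> cache hit

The open question, discussed in the comments below, is how such a scheme would notice that the value was modified inside the dynamic workflow in the meantime.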

Are you sure this issue hasn't been raised already?

  • Yes

Have you read the Code of Conduct?

  • Yes
fg91 added the bug and untriaged labels on Jul 6, 2023
eapolinario added the flytekit label and removed the untriaged label on Jul 7, 2023
Tom-Newton commented

I've got the same problem. I raised it on the Flyte Slack, so there is a bit of relevant discussion there: https://flyte-org.slack.com/archives/CP2HDHKE1/p1697208167873249

eapolinario added the backlogged label on Oct 16, 2023

kumare3 commented Oct 16, 2023

Reason for the bug: #4246. Let's track that as the master issue.

wild-endeavor commented

I kinda feel like this issue is a bit different from #4246, but maybe I just haven't dug deeply enough - I haven't traced through all the links in that other issue.

@fg91 did you try using HashMethod at all (used in the other ticket)? Something like

import typing
import torch.nn as nn
from typing_extensions import Annotated
from flytekit import task, workflow, dynamic, HashMethod


def hash_model(m: nn.Module) -> str:
    return str(hash(m))


Model = Annotated[nn.Module, HashMethod(hash_model)]


@task(cache=True, cache_version="0.1")
def train(model: nn.Module) -> nn.Module:
    print(f"Training model {model}")
    return model


@task(cache=True, cache_version="0.1")
def other_task(param: int) -> int:
    print(f"Doing something else with param {param}")
    return param


@dynamic(cache=True, cache_version="0.1")
def sub_wf(model: nn.Module, param: int) -> typing.Tuple[nn.Module, int]:
    other_task(param=param)
    train(model=model)

    return model, param


@task(cache=True, cache_version="0.1")
def create_model() -> Model:
    return nn.Linear(1, 1)


@workflow
def wf(param: int = 1):
    model = create_model()
    sub_wf(model=model, param=param)

You'd have to replace the hash with something real, of course; the Python hash function has too much entropy in it. I'm not sure I agree with the goal, actually: somehow detecting that the object is just passed through and not uploading it again. How would flytekit distinguish between these two?

# What you had above
@dynamic(cache=True, cache_version="0.1")
def sub_wf(model: nn.Module, param: int) -> tuple[nn.Module, int]:
    other_task(param=param)
    train(model=model)

    return model, param


# A variant that modifies the model before passing it on
@dynamic(cache=True, cache_version="0.1")
def sub_wf(model: nn.Module, param: int) -> tuple[nn.Module, int]:
    model = update_something(model)  # make a change to the model
    other_task(param=param)
    train(model=model)

    return model, param

With the HashMethod change, what should happen is: yes, the model does get downloaded, but it should get uploaded again with that same hash, so you still incur the download/re-hash/upload. But the downstream train task shouldn't re-run. @fg91 would you mind trying out the hash method to see if that at least addresses the current problem?
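For the "something real" part, one possible direction (a sketch only; it hashes parameter names and values but ignores the module's architecture and forward method, and assumes numpy is available alongside torch) might be:

import hashlib

import torch.nn as nn


def hash_model(m: nn.Module) -> str:
    # Hash parameter/buffer names and raw tensor bytes in a deterministic order.
    hasher = hashlib.sha256()
    for name, tensor in sorted(m.state_dict().items()):
        hasher.update(name.encode())
        hasher.update(tensor.detach().cpu().numpy().tobytes())
    return hasher.hexdigest()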

The difference I feel vs the other ticket is that in the other ticket:

  • In df = load_wine(as_frame=True).frame, the df is coming from within the body of the dynamic task, not passed as an input.
  • The load_wine function is not a flytekit task. It's just a function.

But perhaps these differences don't matter. Will need to think more.

As always thank you for the screenshots and very clear explanation. Super helpful.

Tom-Newton commented

I think you are right, @wild-endeavor. Sorry for confusing things. So probably the solution to this issue is to use HashMethod, at which point you might run into #4246.


fg91 commented Nov 1, 2023

@fg91 did you try using HashMethod at all (used in the other ticket)? [...]
I'm not sure if I agree with the goal actually. Somehow detect that it's passed through and not upload again. How would flytekit distinguish between these two

@dynamic(cache=True, cache_version="0.1")
def sub_wf(model: nn.Module, param: int) -> tuple[nn.Module, int]:
    model = update_something(model)  # make a change to the model
    ...

Hey @wild-endeavor,

Sorry for the late reply. I actually wasn't aware of the HashMethod mechanism. I agree with your concern that flytekit wouldn't be able to "pass the model through without re-upload" in the dynamic workflow and still detect whether it was changed in the meantime.

To test whether the proposal works, I forced the hash to a constant for simplicity:

def hash_model(m: nn.Module) -> str:
    return "123456"

and replaced nn.Module with Model in all type hints (was it on purpose that you replaced only one instance?).

The model should now be considered the same.
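Roughly, the modified type hints looked like this (a sketch; other_task, wf, and the task bodies stayed as in the original example):

Model = Annotated[nn.Module, HashMethod(hash_model)]


@task(cache=True, cache_version="0.1")
def train(model: Model) -> Model: ...


@dynamic(cache=True, cache_version="0.1")
def sub_wf(model: Model, param: int) -> tuple[Model, int]: ...


@task(cache=True, cache_version="0.1")
def create_model() -> Model: ...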

First execution:

Screenshot 2023-11-01 at 16 09 05

Relaunch with changed param:

Screenshot 2023-11-01 at 16 11 08

Did I make any mistake in this experiment?


I agree that if this mechanism worked and made it possible to "pass offloaded objects through dynamic tasks", it would be a good solution.

Hashing torch models (including the architecture, the forward method, and the state dict/parameters) is not a trivial problem, but that is another story; the issue here affects all offloaded objects in dynamic workflows ...

github-actions bot commented

Hello 👋, this issue has been inactive for over 9 months. To help maintain a clean and focused backlog, we'll be marking this issue as stale and will engage on it to decide if it is still applicable.
Thank you for your contribution and understanding! 🙏

github-actions bot added the stale label on Jul 29, 2024
alekpikl commented

Still an issue for us. Since the issue is 9 months old, is there any news? Thanks.
