[BUG] Dynamic workflows break caching of artifacts offloaded to blob storage #3842
Comments
I've got the same problem. I raised it on the Flyte Slack, so there is a bit of relevant discussion there: https://flyte-org.slack.com/archives/CP2HDHKE1/p1697208167873249
Reason for the bug: #4246. Let's track that as the master issue.
I kinda feel like this issue is a bit different than #4246, but maybe I just haven't dug deeply enough - I haven't traced through all the links in that other issue. @fg91 did you try using

```python
import typing

import torch.nn as nn
from typing_extensions import Annotated

from flytekit import task, workflow, dynamic, HashMethod


def hash_model(m: nn.Module) -> str:
    return str(hash(m))


Model = Annotated[nn.Module, HashMethod(hash_model)]


@task(cache=True, cache_version="0.1")
def train(model: nn.Module) -> nn.Module:
    print(f"Training model {model}")
    return model


@task(cache=True, cache_version="0.1")
def other_task(param: int) -> int:
    print(f"Doing something else with param {param}")
    return param


@dynamic(cache=True, cache_version="0.1")
def sub_wf(model: nn.Module, param: int) -> typing.Tuple[nn.Module, int]:
    other_task(param=param)
    train(model=model)
    return model, param


@task(cache=True, cache_version="0.1")
def create_model() -> Model:
    return nn.Linear(1, 1)


@workflow
def wf(param: int = 1):
    model = create_model()
    sub_wf(model=model, param=param)
```

You'd have to replace the hash with something real of course; the Python `hash` function has too much entropy in it.

I'm not sure I agree with the goal actually ("somehow detect that it's passed through and not upload again"). How would flytekit distinguish between these two?

```python
# What you had above
@dynamic(cache=True, cache_version="0.1")
def sub_wf(model: nn.Module, param: int) -> tuple[nn.Module, int]:
    other_task(param=param)
    train(model=model)
    return model, param
```

```python
@dynamic(cache=True, cache_version="0.1")
def sub_wf(model: nn.Module, param: int) -> tuple[nn.Module, int]:
    model = update_something(model)  # make a change to the model
    other_task(param=param)
    train(model=model)
    return model, param
```

With the `HashMethod` change, what should happen is: yes, the model does get downloaded, but it should get uploaded again with that same hash, so you do still incur the download/rehash/upload. But the subsequent downstream task should get a cache hit.

The difference I feel vs the other ticket is that in the other ticket:
As always, thank you for the screenshots and very clear explanation. Super helpful.
I think you are right @wild-endeavor. Sorry for confusing things. So probably the solution to this issue is to use `HashMethod`.
Hey @wild-endeavor, sorry for the late reply. I wasn't aware of the `HashMethod` annotation. To test whether the proposal works, and for simplicity, I forced the same hash

```python
def hash_model(m: nn.Module) -> str:
    return "123456"
```

and replaced the plain type annotations with the annotated `Model` type as proposed. The model should now be considered the same.

First execution:

(screenshot)

Relaunch with changed `param`:

(screenshot)

Did I make any mistake in this experiment? I agree that if this mechanism worked and allowed "passing offloaded objects through dynamic tasks", it would be a good solution. Hashing torch models (including architecture, forward method, state dict/parameters) is not a trivial problem, but that is another story, and this is a problem that affects all offloaded objects in dynamic workflows ...
Hello 👋, this issue has been inactive for over 9 months. To help maintain a clean and focused backlog, we'll be marking this issue as stale and will engage on it to decide if it is still applicable.
Still an issue for us. Since the issue is 9 months old, is there any news? Thanks.
Describe the bug
Some type transformers offload artifacts to blob storage as follows (e.g. here):
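The linked snippet is not included in this capture; as a rough sketch, such a transformer's `to_literal` does approximately the following (paraphrased, not the exact flytekit code; names and APIs may differ between flytekit versions):

```python
# Rough sketch of an offloading type transformer's to_literal (paraphrased;
# not the exact flytekit implementation, APIs may differ between versions).
import pathlib

import torch
from flytekit import FlyteContext
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar


def to_literal_sketch(ctx: FlyteContext, python_val: torch.nn.Module) -> Literal:
    # Serialize the Python object to a local temporary file ...
    local_path = ctx.file_access.get_random_local_path() + ".pt"
    pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(python_val, local_path)

    # ... and upload it to a newly generated *random* remote path. Every call
    # to to_literal therefore produces a different URI, which is what breaks
    # cache lookups when the value merely passes through a dynamic workflow.
    remote_path = ctx.file_access.get_random_remote_path(local_path)
    ctx.file_access.put_data(local_path, remote_path, is_multipart=False)

    meta = BlobMetadata(
        type=_core_types.BlobType(
            format="PyTorchModule",
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
    )
    return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))
```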
When objects of such types are passed to `@dynamic` workflows and then passed along to tasks called within the dynamic workflow, this behaviour always leads to cache misses. The reason is that in the dynamic workflow, the objects are deserialized and then serialized again to a different random remote path.

Expected behavior
There should not be cache misses in this situation.
Additional context to reproduce
Let us consider this example workflow:
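The snippet referenced here is not preserved in this page capture; the following reconstruction is based on the version quoted by @wild-endeavor in the comments above, using the plain `nn.Module` annotation (i.e. without `HashMethod`):

```python
import typing

import torch.nn as nn

from flytekit import dynamic, task, workflow


@task(cache=True, cache_version="0.1")
def create_model() -> nn.Module:
    return nn.Linear(1, 1)


@task(cache=True, cache_version="0.1")
def train(model: nn.Module) -> nn.Module:
    print(f"Training model {model}")
    return model


@task(cache=True, cache_version="0.1")
def other_task(param: int) -> int:
    print(f"Doing something else with param {param}")
    return param


@dynamic(cache=True, cache_version="0.1")
def sub_wf(model: nn.Module, param: int) -> typing.Tuple[nn.Module, int]:
    other_task(param=param)
    train(model=model)
    return model, param


@workflow
def wf(param: int = 1):
    model = create_model()
    sub_wf(model=model, param=param)
```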
Screenshots
The first execution with `param=1` results in cache puts for all tasks:

(screenshot)

Next, let's re-run the workflow but with `param=2`:

(screenshot)

It is expected that `other_task` has a cache miss since we changed `param`. However, since `create_model` had a cache hit, `train` should have had one as well. Instead, one can observe that the output of `create_model` (retrieved from cache) ...

(screenshot)

... is not the same as the input to `train`:

(screenshot)

The reason is that the respective type transformer deserialized the artifact and serialized it again to a new random bucket path.
I don't know exactly how this could be solved, but the type engine should "somehow detect that it is being executed in a dynamic workflow" and "simply forward the artifacts instead of serializing them to a different location".