Strange behavior of saving sharded trainstate in GCP. #660

Open
chiamp opened this issue Jan 11, 2024 · 3 comments

chiamp commented Jan 11, 2024

A user posted in the Flax discussions about an Orbax discrepancy between different GCE zones. Do different zones have different Orbax versions?

==================================================================

What happened

When I save my sharded state with Orbax in asia-northeast3-a on GCE, Orbax creates a /tmp/orbax_ckpt/0/_sharding file that starts with:

{"dropout_rng":"{\"sharding_type\": \"NamedSharding\", \"shape\": [2, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}","opt_state.0.0.count":"{\"sharding_type\": \"NamedSharding\", \"shape\": [2, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}",
...

My sharded state has a "dropout_rng" entry, so the file above makes sense.

However, when I run the same script in another region, such as asia-southeast1-b, Orbax creates a _sharding file without readable layer names, for example:

{"ZHJvcG91dF9ybmc=":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}","b3B0X3N0YXRlLjAuMC5jb3VudA==":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}",
...

Theory

I suspect this is related to OCDBT, because the only difference between the terminal outputs is that asia-northeast3-a prints the following OCDBT initialization message, while the other regions do not:
type_handlers.py:223] OCDBT is initialized successfully..

I confirmed tensorstore==0.1.51 in all regions.

Can anyone help, please?

Thank you.

Originally posted by @sw32-seo in google/flax#3538

@niketkumar (Collaborator)

@sw32-seo Can you please check the Orbax version in all regions?

sw32-seo commented Jan 16, 2024

@niketkumar I used the same Docker image, with orbax-checkpoint==0.4.7, in all regions.
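To check versions the same way in every region, a small sketch like the following could be run inside each container to print the installed versions of the packages mentioned in this thread. It relies only on the standard library's `importlib.metadata`; the helper `installed_versions` is hypothetical, and including `jax` and `flax` in the list is an assumption, since the thread does not report their versions.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version string}, or None for packages that aren't installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Run this in each region's container and compare the printed lines.
    for pkg, ver in installed_versions(["orbax-checkpoint", "tensorstore", "jax", "flax"]).items():
        print(f"{pkg}: {ver}")
```

Identical output across regions would rule out a version mismatch in the installed Python packages, though not differences in the runtime environment itself.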

@liangyaning33 (Contributor)

Hi @chiamp, thanks for raising the issue. We had to change how pytree key names are encoded, since special characters like '~' were not encoded properly by the previous encoding.
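As a minimal sketch, assuming the new scheme is plain URL-safe base64 of each key's UTF-8 bytes (an assumption; Orbax's actual encoding code is not shown in this thread), the following illustrates why such an encoding sidesteps special characters like '~': every token uses only letters, digits, '-', '_' and '=' padding, and it decodes back to the original name. `encode_key`, `decode_token`, and the example key `params.layer~1.kernel` are hypothetical.

```python
import base64
import string

# Characters that can appear in URL-safe base64 output, including '=' padding.
URLSAFE_ALPHABET = set(string.ascii_letters + string.digits + "-_=")


def encode_key(name: str) -> str:
    """Hypothetical encoder: URL-safe base64 of the key's UTF-8 bytes."""
    return base64.urlsafe_b64encode(name.encode("utf-8")).decode("ascii")


def decode_token(token: str) -> str:
    """Inverse of encode_key."""
    return base64.urlsafe_b64decode(token).decode("utf-8")


if __name__ == "__main__":
    for name in ["dropout_rng", "opt_state.0.0.count", "params.layer~1.kernel"]:
        token = encode_key(name)
        # The token never contains '~', '.', or '/', yet round-trips exactly.
        assert set(token) <= URLSAFE_ALPHABET
        assert decode_token(token) == name
        print(f"{name!r} -> {token!r}")
```

Under this assumption, "dropout_rng" encodes to "ZHJvcG91dF9ybmc=", the same key that appears in the asia-southeast1-b _sharding file.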

The names are still correct; they are now encoded with base64 urlsafe_encode. Could you re-run your script in asia-northeast3-a and check whether the sharding file changes? My suspicion is that it will, and that it will match the one you provided below.
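To confirm that the two _sharding files describe the same pytree, a best-effort sketch like the following could re-key either file by readable names. It assumes, as the comment above suggests, that encoded keys are URL-safe base64 of UTF-8 text; the helpers `decode_key` and `readable_sharding` are hypothetical, and decoding is a heuristic (a plain key that happens to be valid base64 would be misread), not an Orbax API.

```python
import base64
import json


def decode_key(key: str) -> str:
    """Decode a URL-safe base64 key; return the key unchanged if it isn't one."""
    try:
        # altchars maps '-'/'_' to the standard alphabet; validate=True rejects
        # characters such as '.' instead of silently discarding them.
        raw = base64.b64decode(key, altchars=b"-_", validate=True)
        return raw.decode("utf-8")
    except ValueError:  # covers binascii.Error and UnicodeDecodeError
        return key


def readable_sharding(sharding_json: str) -> dict:
    """Return the parsed _sharding mapping keyed by readable pytree paths."""
    return {decode_key(k): v for k, v in json.loads(sharding_json).items()}


if __name__ == "__main__":
    # Keys copied from the two _sharding excerpts in this issue; values shortened.
    plain = '{"dropout_rng": "...", "opt_state.0.0.count": "..."}'
    encoded = '{"ZHJvcG91dF9ybmc=": "...", "b3B0X3N0YXRlLjAuMC5jb3VudA==": "..."}'
    print(sorted(readable_sharding(plain)))    # ['dropout_rng', 'opt_state.0.0.count']
    print(sorted(readable_sharding(encoded)))  # ['dropout_rng', 'opt_state.0.0.count']
```

Comparing the values after re-keying would also surface the mesh-shape difference visible in the excerpts ([2, 1] versus [1, 1]).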

Thanks!
